In [1]:
import pandas as pd
import numpy as np
from skimage import io, transform

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
from torchvision import transforms, utils, models
import torchvision as tv
from torchvision.transforms import v2

import lightning as L
import torchmetrics as tm
import torch.nn.functional as F

from IPython.display import clear_output

import os

In [2]:
L.seed_everything(42)

[rank: 0] Seed set to 42


42

### Creating the dataset and dataloader from `intel-image-classification`

In [22]:
# convert PIL image into torch Tensor then does specified transforms from docs: 
# https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
ds_transforms = v2.Compose([
    v2.ToImage(),
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [23]:
def get_dataset(name, **kwargs):
    path = "intel-image-classification/" + name
    ds = tv.datasets.ImageFolder(path, transform=ds_transforms)
    if 'rand_fraction' in kwargs:
        sp_frac = kwargs['rand_fraction']
        if type(sp_frac) is not float or not (0 < sp_frac < 1):
            raise ValueError(f"Invalid `rand_fraction` argument: [{sp_frac}]. Should be a float, s.t. 0.0 < x < 1.0")
        ds, _ = random_split(ds, [sp_frac, 1 - sp_frac], generator=torch.Generator().manual_seed(42))
    return ds

def get_train(aug_name=None, **kwargs):
    if aug_name:
        return get_dataset("seg_train/seg_train_aug/" + aug_name, **kwargs)
    return get_dataset("seg_train/seg_train", **kwargs)

def get_test(**kwargs):
    return get_dataset("seg_test/seg_test", **kwargs)

In [40]:
TRAIN_FRACTION = 1/100

In [41]:
augmentation = "canny2"
train_dataset = ConcatDataset((get_train(rand_fraction=TRAIN_FRACTION), get_train(augmentation, rand_fraction=TRAIN_FRACTION)))
test_dataset = get_test()

In [42]:
len(train_dataset)

282

In [43]:
test_dataset

Dataset ImageFolder
    Number of datapoints: 3000
    Root location: intel-image-classification/seg_test/seg_test
    StandardTransform
Transform: Compose(
                 ToImage()
                 Resize(size=[256], interpolation=InterpolationMode.BILINEAR, antialias=True)
                 CenterCrop(size=(224, 224))
                 ToDtype(scale=True)
                 Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
           )

In [44]:
test_dataset, val_dataset = random_split(test_dataset, [.5, .5])

### Creating models & Training

In [45]:
feature_extractors = {}

In [46]:
# create resnet50 feature extractor
resnet50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet50.eval()
resnet50_backbone = list(resnet50.children())[:-1]
resnet50_feat_extractor = nn.Sequential(*resnet50_backbone)
feature_extractors['resnet50'] = resnet50_feat_extractor
clear_output()

In [47]:
class IntelClassifier(L.LightningModule):
    def __init__(self, feature_extractor_name, output_features, num_classes, classifier=None, optimizer=torch.optim.Adam, lr=1e-2):
        super().__init__()
        self.save_hyperparameters()
        if feature_extractor_name not in feature_extractors:
            raise ValueError(f"`feature_extractor_name` argument is invalid (should be one of {list(feature_extractors.keys())})")
        self.feature_extractor = feature_extractors[feature_extractor_name]
        self.classifier = classifier if classifier else nn.Sequential( # classifier layers after the feature extraction
            nn.Linear(output_features, 512),
            nn.LeakyReLU(),
            nn.Dropout(.2),
            nn.Linear(512, num_classes)
        )
        self.accuracy = tm.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.optimizer = optimizer
        self.lr = lr

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            features = self.feature_extractor(x).flatten(1)
        return self.classifier(features)

    def _batch_step(self, batch, batch_kind):
        if batch_kind == 'train':
            self.classifier.train()
        else:
            self.classifier.eval()
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = self.accuracy(y_hat, y)
        # logging onto tensorboard
        self.log(f"{batch_kind}_loss", loss, prog_bar=True)
        self.log(f"{batch_kind}_acc_f1", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._batch_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._batch_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._batch_step(batch, 'test')

    def predict_step(self, batch, batch_idx):
        self.eval()
        x, _ = batch
        return self(x)
    
    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), lr=self.lr)
        return optimizer

In [48]:
resnet50_model = IntelClassifier('resnet50', 2048, 6, lr=1e-5)

In [49]:
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, BATCH_SIZE, num_workers=4)

In [50]:
# set up loggers
tb_logger = L.pytorch.loggers.TensorBoardLogger(save_dir='')
csv_logger = L.pytorch.loggers.CSVLogger(save_dir='')

In [51]:
trainer = L.Trainer(logger=[tb_logger, csv_logger], callbacks=[L.pytorch.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5)], max_epochs=50)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [52]:
torch.set_float32_matmul_precision('high')

In [53]:
trainer.fit(resnet50_model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type               | Params
---------------------------------------------------------
0 | feature_extractor | Sequential         | 23.5 M
1 | classifier        | Sequential         | 1.1 M 
2 | accuracy          | MulticlassAccuracy | 0     
---------------------------------------------------------
24.6 M    Trainable params
0         Non-trainable params
24.6 M    Total params
98.241    Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |                                                                         | 0/? [00:00<?, ?it…

/storage/ice1/5/4/rso31/miniforge3/envs/cv_env/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

Validation: |                                                                              | 0/? [00:00<?, ?it…

`Trainer.fit` stopped: `max_epochs=50` reached.


In [54]:
CKPT_PATH = 'resnet50-intel-canny2-1_100train.ckpt'

In [55]:
trainer.save_checkpoint(CKPT_PATH)