In [1]:
# Reference: https://github.com/Lightning-AI/lightning/blob/master/examples/pl_domain_templates/computer_vision_fine_tuning.py

In [13]:
from pathlib import Path

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, Accuracy, F1Score
from torchvision import transforms
from torchvision.datasets import ImageFolder

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info

import timm

import numpy as np

In [15]:
class ImageDataModule(pl.LightningDataModule):
    """Data module serving ``ImageFolder`` datasets from ``<data_dir>/train`` and ``<data_dir>/valid``.

    Training batches are drawn with a ``WeightedRandomSampler`` whose per-sample
    weights are the inverse frequency of each sample's class, so rare classes are
    sampled more often than their raw share of the dataset.

    Args:
        data_dir: Directory containing the ``train`` and ``valid`` sub-folders
            (``str`` or ``pathlib.Path``).
        batch_size: Batch size used by both dataloaders.
        num_workers: Number of worker processes used by both dataloaders.
        size: Side length, in pixels, of the square crops fed to the model.

    Attributes set by ``setup``:
        train_ds, valid_ds: The two ``ImageFolder`` datasets.
        class_to_idx, classes, targets, num_classes: Taken from the training set.
        class_weight: ``1 / count`` per class, float32 tensor.
        sample_weight: ``class_weight`` looked up for every training sample.
        sampler: ``WeightedRandomSampler`` built from ``sample_weight``.
    """

    def __init__(self, data_dir, batch_size = 32, num_workers = 0, size = 224):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.size = size

    def setup(self, stage=None):
        """Build transforms, datasets and the class-balancing sampler.

        Args:
            stage: Stage name passed by the trainer; it is not used here, the same
                objects are built for every stage.
        """
        # Channel statistics shared by the training and validation pipelines.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # ToTensor comes first, so every following transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomResizedCrop((self.size, self.size)),
                transforms.RandomRotation(90, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.RandomHorizontalFlip(p=0.25),
                transforms.RandomVerticalFlip(p=0.25),
                transforms.RandomAutocontrast(p=0.25),
                transforms.RandomPerspective(p=0.25),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

        # Validation uses a deterministic resize + centre crop (no augmentation).
        self.valid_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(size=self.size),
                transforms.CenterCrop(size=self.size),
                transforms.Normalize(mean=mean, std=std),
            ]
        )

        # Read from the directory given to the constructor rather than from a
        # notebook-level global, so the module works with any `data_dir`.
        root = Path(self.data_dir)
        self.train_ds = ImageFolder(root=root / "train", transform=self.train_transform)
        self.valid_ds = ImageFolder(root=root / "valid", transform=self.valid_transform)

        self.class_to_idx = self.train_ds.class_to_idx
        self.classes = self.train_ds.classes
        self.targets = self.train_ds.targets
        self.num_classes = len(self.classes)

        # Inverse-frequency class weights. np.unique returns the class indices in
        # sorted order, so position i of `class_count` belongs to class index i
        # as long as every class has at least one training sample.
        targets_np = np.array(self.targets)
        _, class_count = np.unique(targets_np, return_counts=True)
        self.class_weight = torch.tensor(1.0 / class_count).float()

        # One weight per training sample: the weight of that sample's class.
        target_index = torch.as_tensor(targets_np, dtype=torch.long)
        self.sample_weight = self.class_weight[target_index]
        self.sampler = WeightedRandomSampler(self.sample_weight, len(self.sample_weight))

    def train_dataloader(self):
        """Return the training loader; ordering is delegated to the weighted sampler."""
        return DataLoader(self.train_ds,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          sampler=self.sampler)

    def val_dataloader(self):
        """Return the validation loader, iterating the dataset in order."""
        return DataLoader(self.valid_ds,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False)

In [37]:
class TransferLearningModel(LightningModule):
    """Image classifier that fine-tunes a pretrained ``timm`` model.

    Args:
        backbone: Model name passed to ``timm.create_model``.
        num_classes: Number of output classes (also used by the accuracy metrics).
        milestones: Epoch milestones handed to ``MultiStepLR``.
        class_weight: Optional per-class weight tensor for ``CrossEntropyLoss``.
        lr: Learning rate for the Adam optimizer.
        lr_scheduler_gamma: ``gamma`` factor handed to ``MultiStepLR``.
    """

    def __init__(self, backbone = "resnet18", num_classes = 10, milestones = (5, 10),
                 class_weight = None, lr = 1e-3, lr_scheduler_gamma = 1e-1):

        super().__init__()

        # Plain configuration values kept on the instance.
        self.backbone = backbone
        self.num_classes = num_classes
        self.milestones = milestones
        self.lr = lr
        self.lr_scheduler_gamma = lr_scheduler_gamma

        # Pretrained network whose classifier head is sized for `num_classes`.
        self.model = timm.create_model(
            self.backbone,
            pretrained=True,
            num_classes=self.num_classes,
        )

        self.loss = nn.CrossEntropyLoss(weight=class_weight)
        self.softmax = nn.Softmax(dim=1)

        # Separate metric objects for training and validation.
        metric_kwargs = dict(task="multiclass", num_classes=self.num_classes, top_k=1)
        self.train_acc = Accuracy(**metric_kwargs)
        self.val_acc = Accuracy(**metric_kwargs)

        # NOTE: F1 logging was disabled by the author and is not computed here.

        self.save_hyperparameters()

    def forward(self, x):
        """Return the raw class logits produced by the wrapped model."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log loss and accuracy.

        Returns:
            The cross-entropy loss of the batch.
        """
        images, labels = batch
        logits = self.forward(images)

        loss_value = self.loss(logits, labels)

        # The accuracy metric is fed softmax scores rather than raw logits.
        scores = self.softmax(logits)
        batch_acc = self.train_acc(scores, labels)

        self.log("train_loss", loss_value, prog_bar=True)
        self.log("train_acc", batch_acc, prog_bar=True)

        return loss_value

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy for one validation batch."""
        images, labels = batch
        logits = self.forward(images)

        loss_value = self.loss(logits, labels)
        scores = self.softmax(logits)
        batch_acc = self.val_acc(scores, labels)

        self.log("val_loss", loss_value, prog_bar=True)
        self.log("val_acc", batch_acc, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer paired with a verbose ``MultiStepLR`` scheduler."""
        adam = optim.Adam(self.model.parameters(), lr=self.lr)
        step_schedule = MultiStepLR(
            adam,
            milestones=self.milestones,
            gamma=self.lr_scheduler_gamma,
            verbose=True,
        )

        return [adam], [step_schedule]


In [38]:
# Run setup explicitly so that `dm.num_classes` exists before the model is built.
dm = ImageDataModule(data_dir=path, batch_size=32, num_workers=8)
dm.setup(stage="fit")

# The loss is left unweighted (class_weight=None); `dm.class_weight` is the
# alternative the author noted for a class-weighted loss.
model = TransferLearningModel(
    backbone="resnet18",
    num_classes=dm.num_classes,
    class_weight=None,
    lr=1e-3,
    milestones=(5, 10),
)

# Single-GPU run with 16-bit precision for 15 epochs.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=15,
    precision=16,
)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [39]:
# Train the model, letting the data module supply the train/validation loaders.
trainer.fit(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | ResNet             | 11.2 M
1 | loss      | CrossEntropyLoss   | 0     
2 | softmax   | Softmax            | 0     
3 | train_acc | MulticlassAccuracy | 0     
4 | val_acc   | MulticlassAccuracy | 0     
-------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
22.381    Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-03.


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validating: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validating: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validating: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.0000e-03.


Validating: 0it [00:00, ?it/s]