In [8]:
# Based on:
# 1. https://www.kaggle.com/code/devashishpandit/flowers-classification-with-pytorch - dataset processing
# 2. https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html - retraining the network

In [1]:

import numpy as np
import os
import pytorch_lightning as pl
import shutil
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from torchvision import models

## Data processing

In [2]:
train_dir = './train'
val_dir = './val'

In [8]:
train_ratio = 0.80

# # only run once

_, dirs, _ = next(os.walk(train_dir))
images_per_class = np.zeros(5)
for i in range(len(images_per_class)):
    path = os.path.join(train_dir,dirs[i])
    files = np.asarray(os.listdir(path))
    images_per_class[i] = len(files)

val_counter = np.round(images_per_class * (1-train_ratio))

# transfer files
for i in range(len(images_per_class)):
    source_path = os.path.join(train_dir, dirs[i])
    dest_path = os.path.join(val_dir, dirs[i])
    if not os.path.exists(dest_path):
        os.makedirs(dest_path)
    files = np.asarray(os.listdir(source_path))
    for j in range(int(val_counter[i])):
        dst = os.path.join(dest_path, files[j])
        src = os.path.join(source_path, files[j])
        shutil.move(src,dst)

In [3]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

In [4]:
train_data = torchvision.datasets.ImageFolder(root=train_dir, transform=transform)
val_data = torchvision.datasets.ImageFolder(root=val_dir, transform=transform)

In [5]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=16, shuffle=False)

## Loading & modifying AlexNet

In [62]:
class LitAlexNet(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.net = models.alexnet(pretrained=True)
        for param in self.net.parameters():
            param.requires_grad = False
        self.net.classifier[6] = nn.Linear(in_features=4096, out_features=5, bias=True)

    def forward(self, x):
        outputs = self.net(x)
        return outputs

    def loss_fn(self, out, target):
        return nn.CrossEntropyLoss()(out, target)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.net.parameters(), lr=0.005)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = self.loss_fn(outputs, y)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss',  avg_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = self.loss_fn(outputs, y)
        _, preds = torch.max(outputs, 1)
        correct_preds = torch.sum(preds == y.data)
        return {"correct": correct_preds, "loss": loss, "total": len(y.data)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        correct=sum([x["correct"] for  x in outputs])
        total=sum([x["total"] for  x in outputs])
        self.log('validation_accuracy', correct/total)
        self.log('validation_loss',  avg_loss)

In [68]:
model = LitAlexNet()

In [69]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='validation_accuracy',
    dirpath='./',
    filename='models-{epoch:02d}-{validation_accuracy:.2f}',
    save_top_k=3,
    mode='max')

In [70]:
trainer = Trainer(gpus=1, max_epochs=10, callbacks=[checkpoint_callback])
trainer.validate(model, val_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'validation_accuracy': 0.20717592537403107,
 'validation_loss': 1.803727388381958}
--------------------------------------------------------------------------------


[{'validation_accuracy': 0.20717592537403107,
  'validation_loss': 1.803727388381958}]

In [71]:
trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type    | Params
---------------------------------
0 | net  | AlexNet | 57.0 M
---------------------------------
20.5 K    Trainable params
57.0 M    Non-trainable params
57.0 M    Total params
228.097   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [73]:
best_model = LitAlexNet().load_from_checkpoint('./models-epoch=07-validation_accuracy=0.84.ckpt')

In [74]:
trainer.validate(best_model, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'validation_accuracy': 0.8356481790542603,
 'validation_loss': 1.1515005826950073}
--------------------------------------------------------------------------------


[{'validation_accuracy': 0.8356481790542603,
  'validation_loss': 1.1515005826950073}]