In [2]:
import os
import sys
import torch

import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import lightning.pytorch as pl
import numpy as np
import seaborn as sns


from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import EarlyStopping
from torchmetrics import Accuracy
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision import datasets, transforms
from torchvision.models.vision_transformer import vit_b_16
from torchvision.models import ViT_B_16_Weights
from PIL import Image
from lightning.pytorch.tuner import Tuner
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score



In [2]:
DATA_DIR = os.path.join(os.curdir, "data", "raw", "dataset")

In [3]:
TRAIN_SIZE = 0.8
BATCH_SIZE = 32

In [4]:
class VitLightningModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1.2e-4, num_classes=11):
        super(VitLightningModule, self).__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
        self.transform = ViT_B_16_Weights.DEFAULT.transforms()
        
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)
        self.log('test_loss', loss)
        self.log('test_acc', acc)


    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [5]:
class ImageFolderDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=32, train_size=0.8):
        super().__init__()
        self.data_dir = DATA_DIR
        self.batch_size = batch_size
        self.train_size = train_size
        self.transform = ViT_B_16_Weights.DEFAULT.transforms()

    def setup(self, stage=None):
        dataset = datasets.ImageFolder(self.data_dir, transform=self.transform)
        train_size = int(self.train_size * len(dataset))
        val_size = int(0.1 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(dataset, [train_size, val_size, test_size])
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True , num_workers=31)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=31)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [6]:
def plot_loss_curves(logger):
    metrics_path = os.path.join(logger.save_dir, logger.name, f"version_{logger.version}", "metrics.csv")
    
    if not os.path.exists(metrics_path):
        raise FileNotFoundError(f"Metrics file not found at: {metrics_path}")

    metrics = pd.read_csv(metrics_path)

    plt.figure(figsize=(10, 5))
    train_loss = metrics.dropna(subset=['train_loss'])
    val_loss = metrics.dropna(subset=['val_loss'])
    
    plt.plot(train_loss['epoch'], train_loss['train_loss'], label='Training Loss')
    plt.plot(val_loss['epoch'], val_loss['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training vs Validation Loss')
    plt.show()

In [7]:
model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
model.heads.head = nn.Linear(model.heads.head.in_features, 11)

pl_model = VitLightningModule(model)

data_module = ImageFolderDataModule(DATA_DIR, batch_size=BATCH_SIZE, train_size=TRAIN_SIZE)

csv_logger = pl_loggers.CSVLogger('logs/', name='csv_logs')

early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.05, patience=4, mode="min")

trainer = pl.Trainer(callbacks=[early_stop_callback], max_epochs = 120, devices = 1, accelerator='gpu', logger = csv_logger, log_every_n_steps=10)

#very time consuming to find optimal batch size and learning rate, just use first time and then hardcode it
'''
# finding optimal batch size
# maximum batch size was : 128, but takes too long to train so using 32 instead for fast testing purposes
tuner = Tuner(trainer)
new_batch_size = tuner.scale_batch_size(pl_model, datamodule=data_module, mode="power")

# Update data module with new batch size
data_module.batch_size = new_batch_size

# Find the optimal learning rate
# optimal learning rate was 0.00012022644346174131
lr_finder = tuner.lr_find(pl_model, datamodule=data_module)
new_lr = lr_finder.suggestion()
pl_model.learning_rate = new_lr
'''
    

    
trainer.fit(pl_model, data_module)
trainer.test(pl_model, data_module)

plot_loss_curves(csv_logger)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | model     | VisionTransformer   | 85.8 M | train
1 | criterion | CrossEntropyLoss    | 0      | train
2 | accuracy  | MulticlassAccuracy  | 0      | train
3 | transform | ImageClassification | 0      | train
----------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
34

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
test_dataloader = data_module.test_dataloader()

test_model = pl_model.model
test_model.eval()

def compute_metrics(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    
    TP = np.diag(cm)
    FP = cm.sum(axis=0) - TP
    FN = cm.sum(axis=1) - TP
    TN = cm.sum() - (FP + FN + TP)

    return cm, all_labels, all_preds, TP, FP, FN, TN

def calculate_macro_micro_f1(all_labels, all_preds):
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    micro_f1 = f1_score(all_labels, all_preds, average='micro')
    return macro_f1, micro_f1

conf_matrix, all_labels, all_preds, TP, FP, FN, TN = compute_metrics(test_model, test_dataloader)

cm_df = pd.DataFrame(conf_matrix, index =[i for i in ['dew', 'fogsmog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']],
                     columns=[i for i in ['dew', 'fogsmog', 'frost', 'glaze', 'hail', 'lightning', 'rain', 'rainbow', 'rime', 'sandstorm', 'snow']])

plt.figure(figsize=(10,7))
sns.heatmap(cm_df, annot=True, fmt='d', cmap='YlGnBu')
plt.title('Confusion Matrix')
plt.ylabel('Actual Values')
plt.xlabel('Predicted Values')
plt.show()

accuracy = (TP.sum() + TN.sum()) / (TP.sum() + TN.sum() + FP.sum() + FN.sum())
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1_score_per_class = 2 * (precision * recall) / (precision + recall)

macro_f1, micro_f1 = calculate_macro_micro_f1(all_labels, all_preds)

print(f"Accuracy: {accuracy:.4f}")
print(f"Precision per class: {precision}")
print(f"Recall per class: {recall}")
print(f"F1 Score per class: {f1_score_per_class}")
print(f"Macro F1 Score: {macro_f1:.4f}")
print(f"Micro F1 Score: {micro_f1:.4f}")