In [17]:
import lightning as L
import torch
import torch.nn as nn
import torchmetrics
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torch.utils.data import random_split, DataLoader, Dataset
from torchvision.transforms import transforms, Compose
import config_model
from config_model import target_transforms, QrDBest, early_stopping_callback, check_point_callback

In [10]:
ResNet50 = models.resnet50(weights="IMAGENET1K_V2")

In [18]:

class LightningQrDBest(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        
        self.img_dir = config_model.IMAGE_DIR 
        
        self.train_transforms= Compose(config_model.train_transform_compose)
        
        self.val_transforms = Compose(config_model.test_transform_compose)
        
        self.target_transforms = target_transforms
        
        self.batch_size = config_model.BATCH_SIZE
        
    def setup(self, stage: str):
        if stage == 'fit':
            self.full_dataset = QrDBest(self.img_dir)
            self.train_size = int( 0.9 * len(self.full_dataset))
            self.val_size = len(self.full_dataset) - self.train_size
            self.train_dataset, self.val_dataset = random_split(self.full_dataset, [self.train_size, self.val_size],torch.Generator().manual_seed(50))
            self.train_dataset.dataset.img_transforms = self.train_transforms
            self.train_dataset.dataset.target_transforms = self.target_transforms
            self.val_dataset.dataset.img_transforms = self.val_transforms
            self.val_dataset.dataset.target_transforms = self.target_transforms
            
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=config_model.NUM_WORKERS)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=config_model.NUM_WORKERS)

In [33]:

class Net(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.example_input_array = torch.randn(1, 1, 64, 64)
        self.model = models.resnet50(weights="IMAGENET1K_V2")
        num_classes = 4
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        self.model.conv1 = nn.Conv2d(
        in_channels=1,  # Change input channels to 1 for grayscale images
        out_channels=self.model.conv1.out_channels,
        kernel_size=self.model.conv1.kernel_size,
        stride=self.model.conv1.stride,
        padding=self.model.conv1.padding,
        bias=self.model.conv1.bias
        )
        self.loss_function = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy('multiclass', num_classes=4)
        self.f1_score = torchmetrics.F1Score('multiclass', num_classes=4)
    def forward(self, x):
        x = self.model(x)
        
        return x
    
    def _common_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), 1, 64, 64)
        output = self.forward(x)
        loss = self.loss_function(output, y)
        preds = torch.argmax(output, dim=1)
        return loss, preds, y
    
    def training_step(self, batch, batch_idx):
        loss, preds, y = self._common_step(batch, batch_idx)
        y = torch.argmax(y, dim=1)
        train_accuracy = self.accuracy(preds, y)
        train_f1_score = self.f1_score(preds, y)
        self.log_dict({'train_loss': loss, 'train_accuracy': train_accuracy, 'train_f1_score': train_f1_score},
                      prog_bar=True, on_step=False, on_epoch=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, preds, y = self._common_step(batch, batch_idx)
        y = torch.argmax(y, dim=1)
        val_accuracy = self.accuracy(preds, y)
        val_f1_score = self.f1_score(preds, y)
        self.log_dict({"val_loss": loss, "val_accuracy": val_accuracy, "val_f1_score": val_f1_score},
                      prog_bar=True, on_epoch=True, on_step=False)
        return loss
    
    def test_step(self, batch, batch_idx):
        loss, scores, y = self._common_step(batch, batch_idx)
        self.log('test_loss', loss)
        return loss
    
    def predict_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), 1, 64, 64)
        scores = self.forward(x)
        preds = torch.argmax(scores, dim=1)
        return preds
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)


In [34]:
model = Net()

In [21]:
dm = LightningQrDBest()

In [35]:
trainer = L.Trainer(accelerator="gpu", devices=[0], precision=16, callbacks=[early_stopping_callback, check_point_callback],
                    max_epochs=50, min_epochs=1)

/home/parsa/.local/lib/python3.10/site-packages/lightning/fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [36]:
trainer.fit(model=model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params | Mode  | In sizes       | Out sizes
------------------------------------------------------------------------------------------
0 | model         | ResNet             | 23.5 M | train | [1, 1, 64, 64] | [1, 4]   
1 | loss_function | CrossEntropyLoss   | 0      | train | ?              | ?        
2 | accuracy      | MulticlassAccuracy | 0      | train | ?              | ?        
3 | f1_score      | MulticlassF1Score  | 0      | train | ?              | ?        
------------------------------------------------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.040    Total estimated model params size (MB)
154       Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 122/122 [00:03<00:00, 31.18it/s, v_num=2, val_loss=0.168, val_accuracy=0.944, val_f1_score=0.944, train_loss=0.489, train_accuracy=0.801, train_f1_score=0.801]

Metric val_loss improved. New best score: 0.168


Epoch 1: 100%|██████████| 122/122 [00:03<00:00, 34.75it/s, v_num=2, val_loss=0.0799, val_accuracy=0.970, val_f1_score=0.970, train_loss=0.0971, train_accuracy=0.967, train_f1_score=0.967]

Metric val_loss improved by 0.088 >= min_delta = 0.008. New best score: 0.080


Epoch 3: 100%|██████████| 122/122 [00:03<00:00, 34.96it/s, v_num=2, val_loss=0.0357, val_accuracy=0.986, val_f1_score=0.986, train_loss=0.048, train_accuracy=0.986, train_f1_score=0.986] 

Metric val_loss improved by 0.044 >= min_delta = 0.008. New best score: 0.036


Epoch 5: 100%|██████████| 122/122 [00:03<00:00, 35.43it/s, v_num=2, val_loss=0.025, val_accuracy=0.991, val_f1_score=0.991, train_loss=0.0655, train_accuracy=0.985, train_f1_score=0.985] 

Metric val_loss improved by 0.011 >= min_delta = 0.008. New best score: 0.025


Epoch 14: 100%|██████████| 122/122 [00:03<00:00, 34.77it/s, v_num=2, val_loss=0.0494, val_accuracy=0.986, val_f1_score=0.986, train_loss=0.0185, train_accuracy=0.994, train_f1_score=0.994] 

Monitored metric val_loss did not improve in the last 9 records. Best score: 0.025. Signaling Trainer to stop.


Epoch 14: 100%|██████████| 122/122 [00:03<00:00, 34.74it/s, v_num=2, val_loss=0.0494, val_accuracy=0.986, val_f1_score=0.986, train_loss=0.0185, train_accuracy=0.994, train_f1_score=0.994]
