In [None]:
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import pytorch_lightning as pl
import config
from PIL import Image
from torch import nn,optim
import pytorch_lightning as pl
import timm
import torch
from torch.utils.data import Dataset, DataLoader
import torchmetrics
from model import CNNModel
from EfficientNetB0 import EfficientNet
from Xception import XceptionNet
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from dataloader import LogoDataModule
from lightning.pytorch import Trainer, seed_everything
import config

## **Image Augmentations**

In [None]:
# Individual augmentation / preprocessing steps, composed into pipelines below.

# Geometry: fixed 224x224 input, occasional flips, small rotations.
resize = transforms.Resize((224, 224))
hFlip = transforms.RandomHorizontalFlip(0.25)
vFlip = transforms.RandomVerticalFlip(0.25)
rotate = transforms.RandomRotation(15)

# Photometric jitter.
coljtr = transforms.ColorJitter(
    hue=0.1,
    saturation=0.1,
    contrast=0.2,
    brightness=0.2,
)

# Affine warp without translation, zooming in by a factor between 1x and 2x.
raf = transforms.RandomAffine(40, scale=(1, 2), shear=15)

# Crops back to the 224-pixel ImageNet input size.
rrsc = transforms.RandomResizedCrop(224, scale=(0.8, 1.0))
ccp = transforms.CenterCrop(224)

# Channel-wise normalisation with the ImageNet mean / std.
nrml = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


In [None]:
# Training pipeline: resize, random geometric augmentations, crops, colour
# jitter, then tensor conversion and normalisation.
trainTransforms = transforms.Compose(
    [
        resize,
        hFlip,
        vFlip,
        rotate,
        raf,
        rrsc,
        ccp,
        coljtr,
        transforms.ToTensor(),
        nrml,
    ]
)

# Validation pipeline: deterministic resize, tensor conversion, normalisation.
valTransforms = transforms.Compose(
    [
        resize,
        transforms.ToTensor(),
        nrml,
    ]
)

## **Data Loader for Image Folder**

In [None]:
class LogoDataModule(pl.LightningDataModule):
    """Lightning data module serving an image folder as train / validation splits.

    Args:
        data_folder: Root directory passed to ``ImageFolder``.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes used by every dataloader.
        val_split: Fraction of the images held out for validation.
        seed: Seed of the private generator that draws the split, so the split
            is identical every time ``setup`` runs (it runs both when called
            manually and again inside ``Trainer.fit``).
    """

    def __init__(self, data_folder, batch_size=32, num_workers=4, val_split=0.2, seed=42):
        super().__init__()
        self.data_folder = data_folder
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.seed = seed

        self.train_transform = trainTransforms
        self.val_transform = valTransforms

    def setup(self, stage=None):
        """Index the folder and build the train / validation subsets."""
        # Two independent views of the same folder, one per transform pipeline.
        # Assigning `.dataset.transform` on the two Subsets returned by
        # `random_split` does not work: both Subsets wrap the *same* dataset
        # object, so the second assignment overwrites the first and training
        # would silently run with the validation transform (no augmentation).
        train_view = ImageFolder(root=self.data_folder, transform=self.train_transform)
        val_view = ImageFolder(root=self.data_folder, transform=self.val_transform)

        # Kept for callers that inspect the underlying folder (classes, samples).
        self.dataset = train_view

        # Draw one shuffled index list and share it between the two views so
        # the splits are disjoint and jointly cover every image.
        total = len(train_view)
        val_size = int(total * self.val_split)
        generator = torch.Generator().manual_seed(self.seed)
        shuffled = torch.randperm(total, generator=generator).tolist()

        self.val_dataset = torch.utils.data.Subset(val_view, shuffled[:val_size])
        self.train_dataset = torch.utils.data.Subset(train_view, shuffled[val_size:])

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the module-wide batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the augmented training subset."""
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation subset."""
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the validation subset (no separate test split exists)."""
        return self._loader(self.val_dataset, shuffle=False)

#### **Create the Data Module**

In [None]:
# Build the logo data module from the project config, then prepare its
# train / validation datasets right away.
data_module = LogoDataModule(
    val_split=config.VAL_SPLIT,
    batch_size=config.BATCH_SIZE,
    data_folder=config.DATA_FOLDER,
)
data_module.setup()

## **Custom Dataset**

In [None]:
class NQADataset(Dataset):
  """Question-answering dataset that tokenizes one row per item.

  Args:
    data: Table indexed positionally through ``.iloc`` whose rows expose the
      keys ``'question'``, ``'paragraph'`` and ``'answer'``.
    tokenizer: Callable tokenizer returning ``input_ids`` / ``attention_mask``
      tensors and exposing ``pad_token_id``.
    source_max_token_len: Padded / truncated length of the question+paragraph
      encoding.
    target_max_token_len: Padded / truncated length of the answer encoding.
  """

  def __init__(self,data ,tokenizer ,source_max_token_len = 400,target_max_token_len = 32):
    self.tokenizer = tokenizer
    self.data = data
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def __len__(self):
    """Number of rows in the underlying table."""
    return len(self.data)

  def __getitem__(self,index : int):
    """Return the raw answer plus flattened model inputs and labels for one row."""
    data_row = self.data.iloc[index]

    # Question and paragraph are encoded as a pair; only the paragraph (the
    # second segment) is truncated when the pair exceeds the length budget.
    source_encoding = self.tokenizer(
        data_row['question'],
        data_row['paragraph'],
        max_length = self.source_max_token_len,
        padding = "max_length",
        truncation = "only_second",
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")

    target_encoding = self.tokenizer(
        data_row['answer'],
        max_length = self.target_max_token_len,
        padding = "max_length",
        truncation = True,
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")

    # Replace padding with -100 (the ignore index of PyTorch's cross-entropy)
    # so padded positions do not contribute to the loss. The pad id is taken
    # from the tokenizer instead of assuming 0: with a tokenizer whose pad id
    # differs, a literal 0 would leave the padding unmasked and wipe out a
    # real token. Cloning keeps the tokenizer's own output untouched.
    labels = target_encoding["input_ids"].clone()
    labels[labels == self.tokenizer.pad_token_id] = -100

    return dict(
        answer = data_row['answer'],
        input_ids = source_encoding['input_ids'].flatten(),
        attention_mask = source_encoding['attention_mask'].flatten(),
        labels = labels.flatten())

## **Custom Data Module**

In [None]:
class NQADataModule(pl.LightningDataModule):
  """Lightning data module wrapping train / validation / test tables in NQADataset.

  Args:
    train_df, val_df, test_df: Tables handed to ``NQADataset`` for each split.
    MODEL_NAME: Identifier passed to ``AutoTokenizer.from_pretrained``.
    batch_size: Batch size used by every dataloader.
    source_max_token_len: Forwarded to ``NQADataset``.
    target_max_token_len: Forwarded to ``NQADataset``.
    num_workers: Worker processes used by every dataloader (previously a
      hard-coded 4; the default keeps that value).
  """

  def __init__(self,train_df,val_df,test_df,MODEL_NAME,batch_size : int = 8,source_max_token_len : int = 400,target_max_token_len : int = 32,num_workers : int = 4):
    super().__init__()
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.train_df = train_df
    self.test_df = test_df
    self.val_df = val_df
    self.MODEL_NAME = MODEL_NAME
    # NOTE(review): AutoTokenizer is not imported anywhere in the visible
    # notebook imports - presumably `transformers.AutoTokenizer`; confirm and
    # add the import to the top cell.
    self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def _make_dataset(self, frame):
    """Wrap one table in an NQADataset using the shared tokenizer and lengths."""
    return NQADataset(frame,self.tokenizer,self.source_max_token_len,self.target_max_token_len)

  def setup(self,stage=None):
    """Build the three split datasets."""
    self.train_dataset = self._make_dataset(self.train_df)
    self.val_dataset = self._make_dataset(self.val_df)
    self.test_dataset = self._make_dataset(self.test_df)

  def train_dataloader(self):
    """Shuffled loader over the training split."""
    return DataLoader(self.train_dataset,batch_size = self.batch_size,shuffle=True,num_workers=self.num_workers)

  def val_dataloader(self):
    """Ordered loader over the validation split."""
    return DataLoader(self.val_dataset,batch_size = self.batch_size,num_workers=self.num_workers)

  def test_dataloader(self):
    """Ordered loader over the test split."""
    return DataLoader(self.test_dataset,batch_size = self.batch_size,num_workers=self.num_workers)

In [None]:
# Bound to its own name: reusing `data_module` here would overwrite the logo
# data module built earlier, and the final `trainer.fit(model=model,
# datamodule=data_module)` call would then feed question-answering batches to
# the Xception image classifier.
# NOTE(review): train_df / val_df / test_df / MODEL_NAME / BATCH_SIZE are not
# defined in the visible cells - confirm where they come from.
nqa_data_module = NQADataModule(train_df,val_df,test_df,MODEL_NAME,batch_size = BATCH_SIZE)
nqa_data_module.setup()

## **Model Architecture for Language Model**

In [None]:
class NQAModel(pl.LightningModule):
    """Lightning wrapper around a pretrained seq2seq language model.

    The wrapped model computes its own loss when ``labels`` are supplied; this
    module only routes batches to it and logs the per-stage loss.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): AutoModelForSeq2SeqLM and MODEL_NAME are not defined in
        # the visible cells - presumably `transformers.AutoModelForSeq2SeqLM`
        # and a notebook-level constant; confirm and define them before use.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME,return_dict=True)

    def forward(self,input_ids,attention_mask,labels=None):
        """Run the wrapped model and return ``(loss, logits)``."""
        output = self.model(
            input_ids = input_ids,
            attention_mask = attention_mask,
            labels = labels)

        return output.loss, output.logits

    def _shared_step(self, batch, stage):
        """Forward one batch, log ``"<stage>_loss"`` and return the loss."""
        loss, _ = self(batch['input_ids'], batch['attention_mask'], batch['labels'])
        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_step(self,batch,batch_idx):
        """Compute and log ``train_loss`` for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self,batch,batch_idx):
        """Compute and log ``val_loss`` for one batch."""
        return self._shared_step(batch, "val")

    def test_step(self,batch,batch_idx):
        """Compute and log ``test_loss`` for one batch."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (lr = 1e-4)."""
        # The bare name `AdamW` is never imported in this notebook and raised
        # NameError; `optim` (torch.optim) is imported and is what XceptionNet
        # already uses. Note torch's AdamW defaults to weight_decay=0.01.
        return optim.AdamW(self.parameters(),lr = 0.0001)


## **Model Architecture for CNN**

In [None]:
class XceptionNet(pl.LightningModule):
    """Image classifier: frozen pretrained Xception backbone with a new trainable head.

    Args:
        num_classes: Number of output classes of the replacement head.
        lr: Learning rate handed to AdamW.

    The head ends in ``LogSoftmax`` and is trained with ``NLLLoss``.
    """

    def __init__(self,num_classes,lr):
        super().__init__()

        self.Lr = lr

        # Per-batch validation losses collected during one validation epoch.
        self.validation_step_outputs = []

        self.lossfn = nn.NLLLoss()

        self.acc = torchmetrics.Accuracy(task="multiclass",num_classes=num_classes)

        self.model = timm.create_model('xception', pretrained=True)
        self.model.aux_logits=False

        # Freeze every pretrained parameter; the head assigned below is created
        # afterwards and therefore stays trainable.
        for param in self.model.parameters():
            param.requires_grad = False

        in_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
                            nn.BatchNorm1d(in_features),
                            nn.Linear(in_features, 256),
                            nn.Dropout(0.2),
                            nn.ReLU(inplace=True),
                            nn.BatchNorm1d(256),
                            nn.Linear(256, num_classes),
                            nn.LogSoftmax(dim=1))

    def forward(self, x):
        """Return the per-class log-probabilities produced by the model."""
        return self.model(x)

    def training_step(self,batch,batch_idx):
        """Compute, log and return the NLL loss of one training batch."""
        input, label = batch
        output = self(input)
        loss = self.lossfn(output,label)
        # The training loss was previously never logged, so no training curve
        # was available to compare against the validation metrics.
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self,batch,batch_idx):
        """Log the batch loss, store it for the epoch mean and update accuracy."""
        input, label = batch
        output = self(input)
        loss = self.lossfn(output,label)
        # Only the value is needed for the epoch mean; detach before storing.
        self.validation_step_outputs.append(loss.detach())
        self.log("val_loss", loss)
        y_pred = torch.argmax(output,dim=1)
        self.acc.update(y_pred, label)

    def on_validation_epoch_end(self):
        """Log the epoch-mean validation loss and accuracy, then reset state."""
        if self.validation_step_outputs:
            # torch.stack keeps the values on their device; building a tensor
            # from a Python list of tensors converts them one element at a time.
            mean_val = torch.stack(self.validation_step_outputs).mean()
        else:
            # No validation batches ran: report NaN, as the mean of nothing.
            mean_val = torch.tensor(float("nan"))
        # 'mean_val' is the quantity ModelCheckpoint monitors; sync_dist makes
        # every device report the same value when training on several devices.
        self.log('mean_val', mean_val, sync_dist=True)
        self.validation_step_outputs.clear()  # free memory
        val_accuracy = self.acc.compute()
        self.log("val_accuracy", val_accuracy)
        # reset all metrics
        self.acc.reset()
        print(f"\nVal Accuracy: {val_accuracy:.4} "\
        f"Val Loss: {mean_val:.4}")

    def configure_optimizers(self):
        """Return an AdamW optimizer using the configured learning rate."""
        return optim.AdamW(self.parameters(),lr=self.Lr)

In [None]:
# Seed the random number generators through Lightning before building the model.
seed_everything(seed=42, workers=True)

# Xception classifier sized and tuned from the project config.
model = XceptionNet(lr=config.LR, num_classes=config.NUM_CLASSES)

# Keep only the single best checkpoint, ranked by the lowest 'mean_val'
# (the epoch-mean validation loss logged by the model).
checkpoint_callback = ModelCheckpoint(
    monitor='mean_val',
    mode='min',
    save_top_k=1,
    dirpath='checkpoints',
    filename=config.CHECKPOINT_NAME,
    verbose=True,
)


## **Trainer**

In [None]:
# GPU training on all available devices (devices=-1); validation, and with it
# checkpoint selection, runs only every fifth epoch.
trainer = pl.Trainer(
    max_epochs=config.MAX_EPOCHS,
    accelerator="gpu",
    devices=-1,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=5,
)

# Launch training of `model` on the batches served by `data_module`.
trainer.fit(datamodule=data_module, model=model)