In [1]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import nltk
import pandas as pd
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
import csv
import torchvision.models as models
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu
import timm
import pymeteor.pymeteor
import random
from torchinfo import summary
from glob import glob
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from torchvision.transforms import ToTensor
# Device configurationresul
tf = ToTensor()
device = torch.device('cuda:6' if torch.cuda.is_available() else 'cpu')

In [2]:
params={'image_size':1024,
        'lr':2e-4,
        'beta1':0.5,
        'beta2':0.999,
        'batch_size':4,
        'epochs':10000,
        'data_path':'../../data/synth/type/',
        'train_csv':'BR_train.csv',
        'val_csv':'BR_test.csv',
        'vocab_path':'../../data/synth/type/BR_vocab.pkl',
        'embed_size':300,
        'hidden_size':256,
        'num_layers':4,
        'class_count':15,}

In [3]:
class CustomDataset(Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self,data_list,label_list):
        """Set the path for images, captions and vocabulary wrapper.
        
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.data_list=data_list
        self.label_list=label_list
    def trans(self,image):
        if random.random() > 0.5:
            transform = transforms.RandomHorizontalFlip(1)
            image = transform(image)
            
        if random.random() > 0.5:
            transform = transforms.RandomVerticalFlip(1)
            image = transform(image)
            
        return image
    
        
    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        images = self.data_list[index]
        labels=self.label_list[index]
        return images, labels

    def __len__(self):
        return len(self.data_list)
    
    
tf=transforms.ToTensor()
df=pd.read_csv(params['data_path']+params['train_csv'])
train_list=torch.zeros(len(df),3,params['image_size'],params['image_size'])
train_label_list=torch.zeros(len(df),params['class_count'])
for i in tqdm(range(len(df))):
    image=tf(Image.open(glob(params['data_path']+'**/**/'+df.loc[i]['path'])[0]).resize((params['image_size'],params['image_size'])))
    label=os.path.basename(os.path.dirname(glob(params['data_path']+'**/**/'+df.loc[i]['path'])[0]))[2:]
    train_list[i]=image
    train_label_list[i,int(label)-1]=1
df=pd.read_csv(params['data_path']+params['val_csv'])
test_list=torch.zeros(len(df),3,params['image_size'],params['image_size'])
test_label_list=torch.zeros(len(df),params['class_count'])
for i in tqdm(range(len(df))):
    image=tf(Image.open(glob(params['data_path']+'**/**/'+df.loc[i]['path'])[0]).resize((params['image_size'],params['image_size'])))
    label=os.path.basename(os.path.dirname(glob(params['data_path']+'**/**/'+df.loc[i]['path'])[0]))[2:]
    test_list[i]=image
    test_label_list[i,int(label)-1]=1
train_dataset=CustomDataset(train_list,train_label_list)
test_dataset=CustomDataset(test_list,test_label_list)
train_dataloader=DataLoader(train_dataset,batch_size=params['batch_size'],shuffle=True)
val_dataloader=DataLoader(test_dataset,batch_size=params['batch_size'],shuffle=True)

100%|██████████| 8332/8332 [16:49<00:00,  8.26it/s]
100%|██████████| 1042/1042 [01:36<00:00, 10.75it/s]


In [4]:
class FeatureExtractor(nn.Module):
    """Feature extoractor block"""
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        cnn1= timm.create_model('efficientnetv2_xl')
        self.feature_ex = nn.Sequential(*list(cnn1.children())[:-1])

    def forward(self, inputs):
        features = self.feature_ex(inputs)
        
        return features
    
class AttentionMILModel(nn.Module):
    def __init__(self, num_classes, image_feature_dim,feature_extractor_scale1: FeatureExtractor):
        super(AttentionMILModel, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Remove the classification head of the CNN model
        self.feature_extractor = feature_extractor_scale1
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(image_feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        
        # Classification layer
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        batch_size, channels, height, width = inputs.size()
        
        # Flatten the inputs
        inputs = inputs.view(-1, channels, height, width)
        
        # Feature extraction using the pre-trained CNN
        features = self.feature_extractor(inputs)  # Shape: (batch_size , 2048, 1, 1)
        
        # Reshape features
        features = features.view(batch_size, -1)  # Shape: (batch_size, num_tiles, 2048)
        
        
        
        
        
        # Classification layer
        logits = self.classification_layer(features)  # Shape: (batch_size, num_classes)
        
        return logits  


Feature_Extractor=FeatureExtractor()
encoder = AttentionMILModel(params['class_count'], 1280, Feature_Extractor).to(device)
criterion = nn.CrossEntropyLoss()
model_param =list(encoder.parameters())
optimizer = torch.optim.Adam(model_param, lr=params['lr'], betas=(params['beta1'], params['beta2']))
# summary(encoder, input_size=(params['batch_size'], 3, params['image_size'], params['image_size']))

In [5]:
plt_count = 0
sum_loss = 1000.0
scheduler = 0.90
teacher_forcing = 0.3
for epoch in range(params['epochs']):
    train = tqdm(train_dataloader)
    count = 0
    train_loss = 0.0
    correct_train = 0
    total_train = 0
    for images, label in train:
        count += 1
        images = images.to(device)
        label = label.to(device)
        features = F.softmax(encoder(images), dim=1)
        loss = criterion(features, label)
        encoder.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        
        # Accuracy calculation for training
        _, predicted = torch.max(features, 1)
        _, label_1= torch.max(label, 1)
        correct_train += (predicted == label_1).sum().item()
        total_train += label_1.size(0)
        
        train.set_description(f"train epoch: {epoch + 1}/{params['epochs']} Step: {count + 1} loss: {train_loss / count:.4f} accuracy: {100 * correct_train / total_train:.2f}%")

    with torch.no_grad():
        val_count = 0
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        val = tqdm(val_dataloader)
        for images, label in val:
            val_count += 1
            images = images.to(device)
            label = label.to(device)
            features = F.softmax(encoder(images), dim=1)
            loss = criterion(features, label)
            val_loss += loss.item()
            
            # Accuracy calculation for validation
            _, predicted = torch.max(features, 1)
            _, label_1= torch.max(label, 1)
            correct_val += (predicted == label_1).sum().item()
            total_val += label_1.size(0)
            
            val.set_description(f"val epoch: {epoch + 1}/{params['epochs']} Step: {val_count + 1} loss: {(val_loss / val_count):.4f} accuracy: {100 * correct_val / total_val:.2f}%")

    if val_loss < sum_loss:
        sum_loss = val_loss
        torch.save(encoder.state_dict(), '../../model/captioning/BR_encoder1_check.pth')

train epoch: 1/10000 Step: 2084 loss: 2.3870 accuracy: 43.10%: 100%|██████████| 2083/2083 [20:18<00:00,  1.71it/s]
val epoch: 1/10000 Step: 262 loss: 2.4376 accuracy: 37.24%: 100%|██████████| 261/261 [00:53<00:00,  4.85it/s]
train epoch: 2/10000 Step: 2084 loss: 2.5844 accuracy: 23.13%: 100%|██████████| 2083/2083 [20:55<00:00,  1.66it/s]
val epoch: 2/10000 Step: 262 loss: 2.6008 accuracy: 21.59%: 100%|██████████| 261/261 [00:51<00:00,  5.02it/s]
train epoch: 3/10000 Step: 2084 loss: 2.6004 accuracy: 21.59%: 100%|██████████| 2083/2083 [19:43<00:00,  1.76it/s]
val epoch: 3/10000 Step: 262 loss: 2.6040 accuracy: 21.31%: 100%|██████████| 261/261 [00:57<00:00,  4.57it/s]
train epoch: 4/10000 Step: 2084 loss: 2.5802 accuracy: 23.62%: 100%|██████████| 2083/2083 [19:58<00:00,  1.74it/s]
val epoch: 4/10000 Step: 262 loss: 2.6055 accuracy: 21.21%: 100%|██████████| 261/261 [00:55<00:00,  4.72it/s]
train epoch: 5/10000 Step: 798 loss: 2.5566 accuracy: 26.10%:  38%|███▊      | 797/2083 [07:28<12:03

KeyboardInterrupt: 

In [None]:
predicted