In [1]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import nltk
import pandas as pd
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
import csv
import torchvision.models as models
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu
import timm
import random
from torchinfo import summary
from glob import glob
from torchvision.transforms import ToTensor
nltk.download('punkt')
tf = ToTensor()
# Device configurationresul
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
params={'image_size':1024,
        'lr':2e-4,
        'beta1':0.5,
        'beta2':0.999,
        'batch_size':8,
        'epochs':10000,
        'data_path':'../../data/synth/type/',
        'train_csv':'BR_train.csv',
        'val_csv':'BR_val.csv',
        'vocab_path':'../../data/synth/type/BR_vocab.pkl',
        'embed_size':300,
        'hidden_size':256,
        'num_layers':1,}

In [3]:
class CustomDataset(Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self,data_list, data_path,image_size, csv, class_dataset, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.
        
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = data_path+'**/**/'
        self.df = pd.read_csv(data_path+csv)
        self.class_dataset=class_dataset
        self.vocab = vocab
        self.transform = transform
        self.image_size=image_size
        self.data_list=data_list
        
    def trans(self,image):
        if random.random() > 0.5:
            transform = transforms.RandomHorizontalFlip(1)
            image = transform(image)
            
        if random.random() > 0.5:
            transform = transforms.RandomVerticalFlip(1)
            image = transform(image)
            
        return image
    
    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        df = self.df
        vocab = self.vocab
        img_id=df.loc[index]
        
        caption=img_id['caption']
        images = self.trans(self.data_list[index])
        # Convert caption (string) to word ids.
        
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return images, target

    def __len__(self):
        return len(self.data_list)
    

class Vocabulary(object):
    """Simple vocabulary wrapper."""
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]

    def __len__(self):
        return len(self.word2idx)

def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).
    
    We should build custom collate_fn rather than using default collate_fn, 
    because merging caption (including padding) is not supported in default.

    Args:
        data: list of tuple (image, caption). 
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

def idx2word(vocab, indices):
    sentence = []
    
    aa=indices.cpu().numpy()
    
    for index in aa:
        word = vocab.idx2word[index]
        sentence.append(word)
    return sentence
def word2sentence(words_list):
    sentence=''
    for word in words_list:
        if word.isalnum():
            sentence+=' '+word
        else:
            sentence+=word
    return sentence

In [4]:

class FeatureExtractor(nn.Module):
    """Feature extoractor block"""
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        cnn1= timm.create_model('efficientnetv2_s')
        self.feature_ex = nn.Sequential(*list(cnn1.children())[:-1])

    def forward(self, inputs):
        features = self.feature_ex(inputs)
        
        return features
    
class AttentionMILModel(nn.Module):
    def __init__(self, num_classes, image_feature_dim,feature_extractor_scale1: FeatureExtractor):
        super(AttentionMILModel, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Remove the classification head of the CNN model
        self.feature_extractor = feature_extractor_scale1
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(image_feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )
        
        # Classification layer
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        batch_size, channels, height, width = inputs.size()
        
        # Flatten the inputs
        inputs = inputs.view(-1, channels, height, width)
        
        # Feature extraction using the pre-trained CNN
        features = self.feature_extractor(inputs)  # Shape: (batch_size , 2048, 1, 1)
        
        # Reshape features
        features = features.view(batch_size, -1)  # Shape: (batch_size, num_tiles, 2048)
        
        
        
        
        
        # Classification layer
        logits = self.classification_layer(features)  # Shape: (batch_size, num_classes)
        
        return logits  
    
class DecoderTransformer(nn.Module):
    def __init__(self, embed_size, vocab_size, num_heads, hidden_size, num_layers, max_seq_length=100):
        super(DecoderTransformer, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, embed_size))
        self.max_seq_length = max_seq_length
        
        # Transformer Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=hidden_size)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
        self.linear = nn.Linear(embed_size, vocab_size)
        
    def forward(self, features, captions, lengths, teacher_forcing_ratio=0.5):
        batch_size, seq_len = captions.size()
        outputs = torch.zeros(batch_size, seq_len, self.linear.out_features).to(captions.device)
        
        # Positional encoding을 더해 임베딩 생성
        captions_embedded = self.embed(captions) + self.positional_encoding[:, :seq_len, :]
        
        # features의 차원을 (batch_size, 1, embed_size)로 맞춤
        features = features.unsqueeze(1)
        
        # Transformer는 (seq_len, batch_size, embed_size)로 입력을 받으므로 차원 변경
        memory = features.permute(1, 0, 2)  # (1, batch_size, embed_size)
        
        input_caption = captions[:, 0].unsqueeze(1)  # Start with the first token
        for t in range(1, seq_len):
            input_embedded = self.embed(input_caption) + self.positional_encoding[:, :input_caption.size(1), :]
            input_embedded = input_embedded.permute(1, 0, 2)  # (seq_len, batch_size, embed_size)
            
            # Transformer Decoder에 입력
            transformer_output = self.transformer_decoder(input_embedded, memory)
            
            # 다시 차원을 (batch_size, seq_len, embed_size)로 변경 후 Linear layer에 전달
            transformer_output = transformer_output.permute(1, 0, 2)
            output = self.linear(transformer_output[:, -1, :])  # (batch_size, vocab_size)
            
            outputs[:, t, :] = output
            
            # Teacher forcing 결정
            use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
            top1 = output.argmax(1)
            input_caption = captions[:, t].unsqueeze(1) if use_teacher_forcing else top1.unsqueeze(1)
        
        return outputs
    
    def sample(self, features, max_seq_length=None):
        """Greedy Search 방식으로 시퀀스를 샘플링합니다."""
        if max_seq_length is None:
            max_seq_length = self.max_seq_length
        
        # 샘플링을 위한 기본 설정
        inputs = features.unsqueeze(1)  # (batch_size, 1, embed_size)
        sampled_ids = []
        
        # 첫 번째 토큰은 <start> 토큰으로 간주 (일반적으로 ID는 1로 설정)
        input_tokens = torch.ones(features.size(0), 1).long().to(features.device)
        
        for _ in range(max_seq_length):
            # 임베딩 및 positional encoding 적용
            embedded_tokens = self.embed(input_tokens) + self.positional_encoding[:, :input_tokens.size(1), :]
            
            # Transformer는 (seq_len, batch_size, embed_size) 형태의 입력이 필요함
            embedded_tokens = embedded_tokens.permute(1, 0, 2)
            memory = features.unsqueeze(1).permute(1, 0, 2)
            
            # Transformer 디코더를 사용하여 출력 생성
            transformer_output = self.transformer_decoder(embedded_tokens, memory)
            transformer_output = transformer_output.permute(1, 0, 2)
            
            # Linear layer로 vocab 크기로 변환
            output = self.linear(transformer_output[:, -1, :])  # (batch_size, vocab_size)
            _, predicted = output.max(1)
            sampled_ids.append(predicted)
            
            # 예측된 단어를 다음 입력으로 사용
            input_tokens = torch.cat([input_tokens, predicted.unsqueeze(1)], dim=1)
        
        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids
 

In [5]:
with open(params['vocab_path'], 'rb') as f:
        vocab = pickle.load(f)
transform = transforms.Compose([ 
        transforms.RandomCrop(params['image_size']),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])

df=pd.read_csv(params['data_path']+params['train_csv'])
train_list=torch.zeros(len(df),3,params['image_size'],params['image_size'])
for i in tqdm(range(len(df))):
    image=transform(Image.open(glob(params['data_path']+'**/**/'+df.loc[i]['path'])[0]).resize((params['image_size'],params['image_size'])))
    train_list[i]=image
df=pd.read_csv(params['data_path']+params['val_csv'])
test_list=torch.zeros(len(df),3,params['image_size'],params['image_size'])
for i in tqdm(range(len(df))):
    image=transform(Image.open(glob(params['data_path']+'**/**/'+df.loc[i]['path'])[0]).resize((params['image_size'],params['image_size'])))
    test_list[i]=image
train_dataset=CustomDataset(train_list,params['data_path'],params['image_size'],params['train_csv'],'train',vocab,transform=transform)
test_dataset=CustomDataset(test_list,params['data_path'],params['image_size'],params['val_csv'],'val',vocab,transform=transform)
train_dataloader=DataLoader(train_dataset,batch_size=params['batch_size'],shuffle=True,collate_fn=collate_fn)
val_dataloader=DataLoader(test_dataset,batch_size=params['batch_size'],shuffle=True,collate_fn=collate_fn)

100%|██████████| 1651/1651 [01:00<00:00, 27.10it/s]
100%|██████████| 206/206 [00:07<00:00, 28.74it/s]


In [6]:

Feature_Extractor=FeatureExtractor()
encoder = AttentionMILModel(params['embed_size'], 1280, Feature_Extractor).to(device)
decoder = DecoderTransformer(params['embed_size'], len(vocab), 15, params['hidden_size'], params['num_layers']).to(device)

criterion = nn.CrossEntropyLoss()
model_param = list(decoder.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(model_param, lr=params['lr'], betas=(params['beta1'], params['beta2']))
# summary(encoder, input_size=(params['batch_size'], 3, params['image_size'], params['image_size']))

In [None]:

plt_count=0
sum_loss= 1000.0
scheduler = 0.90
teacher_forcing=0.3
for epoch in range(params['epochs']):
    train=tqdm(train_dataloader)
    count=0
    train_loss = 0.0
    for images,captions,lengths in train:
        count+=1
        images = images.to(device)
        captions = captions.to(device)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
        features = encoder(images)
        outputs = decoder(features, captions, lengths, teacher_forcing_ratio=teacher_forcing*(scheduler**epoch))
        outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
        loss = criterion(outputs, targets)
        decoder.zero_grad()
        encoder.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss+=loss.item()
        train.set_description(f"train epoch: {epoch+1}/{params['epochs']} Step: {count+1} loss : {train_loss/count:.4f} ")
    with torch.no_grad():
        val_count=0
        val_loss = 0.0 
        val_bleu_loss=0.0
        val=tqdm(val_dataloader)
        for images,captions,lengths in val:
            val_count+=1
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            features = encoder(images)
            outputs = decoder(features, captions, lengths, teacher_forcing_ratio=0.0)
            outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
            loss = criterion(outputs, targets)
            val_loss+=loss.item()
            val.set_description(f"val epoch: {epoch+1}/{params['epochs']} Step: {val_count+1} loss : {(val_loss/val_count):.4f} ")
    if val_loss<sum_loss:
        sum_loss=val_loss
        torch.save(encoder.state_dict(), '../../model/captioning/BR_encoder_check.pth')
        torch.save(decoder.state_dict(), '../../model/captioning/BR_decoder_check.pth')
        

train epoch: 1/10000 Step: 208 loss : 2.7580 : 100%|██████████| 207/207 [01:23<00:00,  2.48it/s]
val epoch: 1/10000 Step: 27 loss : 3.6710 : 100%|██████████| 26/26 [00:04<00:00,  5.97it/s]
train epoch: 2/10000 Step: 208 loss : 2.0144 : 100%|██████████| 207/207 [01:22<00:00,  2.50it/s]
val epoch: 2/10000 Step: 27 loss : 3.6341 : 100%|██████████| 26/26 [00:05<00:00,  4.83it/s]
train epoch: 3/10000 Step: 208 loss : 1.9543 : 100%|██████████| 207/207 [01:19<00:00,  2.60it/s]
val epoch: 3/10000 Step: 27 loss : 3.1975 : 100%|██████████| 26/26 [00:04<00:00,  6.49it/s]
train epoch: 4/10000 Step: 208 loss : 1.8834 : 100%|██████████| 207/207 [01:17<00:00,  2.66it/s]
val epoch: 4/10000 Step: 27 loss : 2.8748 : 100%|██████████| 26/26 [00:04<00:00,  6.41it/s]
train epoch: 5/10000 Step: 208 loss : 1.8887 : 100%|██████████| 207/207 [01:17<00:00,  2.66it/s]
val epoch: 5/10000 Step: 27 loss : 2.8403 : 100%|██████████| 26/26 [00:03<00:00,  6.54it/s]
train epoch: 6/10000 Step: 208 loss : 1.8662 : 100%|███

In [None]:
outputs