In [1]:
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import nltk
import pandas as pd
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
import csv
import torchvision.models as models
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu
import timm
import random
from torchinfo import summary
from glob import glob
from torchvision.transforms import ToTensor
# Make sure the 'punkt' tokenizer data used by nltk.tokenize.word_tokenize is available.
nltk.download('punkt')

# Shared PIL -> tensor converter.
tf = ToTensor()

# Device configuration.
# NOTE(review): the GPU index is hard-coded to 5; on a host with fewer GPUs
# this device is only rejected later, when a tensor/module is moved to it.
if torch.cuda.is_available():
    device = torch.device('cuda:5')
else:
    device = torch.device('cpu')

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Central run configuration read by the cells below.
params = dict(
    # --- image / optimisation ---
    image_size=512,       # images are resized to image_size x image_size
    lr=2e-4,              # Adam learning rate
    beta1=0.5,            # Adam beta1
    beta2=0.999,          # Adam beta2
    batch_size=16,
    epochs=10000,
    # --- data locations ---
    data_path='../../data/origin_type/',
    train_csv='BR_train.csv',
    val_csv='BR_test.csv',
    vocab_path='../../data/origin_type/BR_vocab.pkl',
    # --- decoder sizes ---
    embed_size=300,       # word-embedding width (also the encoder output width)
    hidden_size=256,      # LSTM hidden width
    num_layers=1,         # stacked LSTM layers
)

In [3]:
class CustomDataset(Dataset):
    """Caption dataset pairing pre-loaded images with tokenised captions.

    Item ``i`` is ``(data_list[i], caption_ids)`` where ``caption_ids`` is built
    from the ``caption`` column of row ``i`` of the CSV file.
    """
    def __init__(self,data_list, data_path,image_size, csv, class_dataset, vocab, transform=None):
        """Store the pre-loaded images and read the caption table.

        Args:
            data_list: indexable collection of images; ``data_list[i]`` must
                correspond to row ``i`` of the CSV file.
            data_path: directory prefix; the CSV is read from ``data_path + csv``.
            image_size: stored on the instance, not used by this class.
            csv: CSV file name (must provide a ``caption`` column).
            class_dataset: free-form split tag, stored but not used by this class.
            vocab: callable mapping a token (str) to its integer id.
            transform: stored but not applied by this class.
        """
        self.root = data_path+'**/**/'
        self.df = pd.read_csv(data_path+csv)
        self.class_dataset=class_dataset
        self.vocab = vocab
        self.transform = transform
        self.image_size=image_size
        self.data_list=data_list

    def __getitem__(self, index):
        """Return one ``(image, caption_ids)`` pair.

        ``caption_ids`` is a 1-D int64 tensor:
        ``<start>``, the ids of the lower-cased caption tokens, ``<end>``.
        """
        vocab = self.vocab
        # Positional lookup: the sampler hands out positions, not index labels.
        row = self.df.iloc[index]
        image = self.data_list[index]

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(row['caption']).lower())
        token_ids = [vocab('<start>')]
        token_ids.extend(vocab(token) for token in tokens)
        token_ids.append(vocab('<end>'))

        # Token ids are integers; build an int64 tensor rather than relying on
        # the legacy float32 ``torch.Tensor(list)`` constructor.
        target = torch.tensor(token_ids, dtype=torch.long)
        return image, target

    def __len__(self):
        """Number of samples, defined by the pre-loaded image collection."""
        return len(self.data_list)
    

class Vocabulary(object):
    """Two-way lookup table between words and consecutive integer ids."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next id to hand out

    def add_word(self, word):
        """Register ``word`` under the next free id; known words are left untouched."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of ``word``, or the id of ``'<unk>'`` when it is unknown."""
        if word in self.word2idx:
            return self.word2idx[word]
        return self.word2idx['<unk>']

    def __len__(self):
        """Number of registered words."""
        return len(self.word2idx)

def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).

    A custom collate_fn is needed because the default one cannot merge
    variable-length captions (padding is required).

    Args:
        data: list of tuple (image, caption).
            - image: torch tensor; all images in the batch share one shape.
            - caption: 1-D torch tensor of variable length.

    Returns:
        images: torch tensor of shape (batch_size, *image_shape).
        targets: long tensor of shape (batch_size, padded_length), zero-padded.
        lengths: list; valid length for each padded caption (descending).

    The input list is left untouched (the batch is sorted into a copy).
    """
    # Sort by caption length (descending), as pack_padded_sequence expects.
    ordered = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    images, captions = zip(*ordered)

    # Merge images (from tuple of N-D tensors to one (N+1)-D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensors to one zero-padded 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths), dtype=torch.long)
    for row, (cap, n) in enumerate(zip(captions, lengths)):
        targets[row, :n] = cap[:n]
    return images, targets, lengths

def idx2word(vocab, indices):
    """Translate a 1-D tensor of token ids into the matching list of words.

    Args:
        vocab: object exposing an ``idx2word`` mapping (id -> word).
        indices: torch tensor of token ids (moved to CPU before the lookup).

    Returns:
        list of words, one per id, in the same order.
    """
    token_ids = indices.cpu().numpy()
    return [vocab.idx2word[token_id] for token_id in token_ids]
def word2sentence(words_list):
    """Glue a list of tokens into one string.

    Purely alphanumeric tokens are prefixed with a single space; every other
    token (punctuation, tokens containing symbols) is appended as-is. The
    result therefore starts with a space when the first token is alphanumeric.
    """
    return ''.join(
        (' ' + token) if token.isalnum() else token
        for token in words_list
    )

In [4]:

class FeatureExtractor(nn.Module):
    """Feature extractor block.

    Wraps a pretrained timm ``maxvit_tiny_tf_512`` model with its last child
    module removed (presumably the classification head - TODO confirm), so the
    forward pass returns the output of the remaining children applied in order.
    """
    def __init__(self):
        super().__init__()
        backbone = timm.create_model('maxvit_tiny_tf_512', pretrained=True)
        # Keep every top-level child except the final one.
        kept_children = list(backbone.children())[:-1]
        self.feature_ex = nn.Sequential(*kept_children)

    def forward(self, inputs):
        """Run ``inputs`` through the truncated backbone and return its output."""
        return self.feature_ex(inputs)
    
class AttentionMILModel(nn.Module):
    """Image encoder: backbone features -> flatten -> linear projection.

    Args:
        num_classes: width of the output vector (the notebook passes the
            decoder's embedding size here).
        image_feature_dim: number of features per image after flattening the
            backbone output; must equal the product of the backbone's
            non-batch output dimensions.
        feature_extractor_scale1: module mapping an image batch to a feature
            tensor whose first dimension is the batch.

    Note:
        ``self.attention`` is constructed but never used by ``forward``. It is
        kept so parameter names (state_dict keys) and initialisation order stay
        compatible with checkpoints written by the original class.
    """
    def __init__(self, num_classes, image_feature_dim,feature_extractor_scale1: nn.Module):
        super(AttentionMILModel, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim

        # Backbone supplied by the caller.
        self.feature_extractor = feature_extractor_scale1

        # Attention scoring head (currently unused in forward, see class docstring).
        self.attention = nn.Sequential(
            nn.Linear(image_feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Projection from the flattened features to the output vector.
        self.classification_layer = nn.Linear(image_feature_dim, num_classes)

    def forward(self, inputs):
        """Encode a batch of images.

        Args:
            inputs: tensor whose first dimension is the batch.

        Returns:
            tensor of shape (batch_size, num_classes).
        """
        batch_size = inputs.size(0)

        features = self.feature_extractor(inputs)

        # Flatten everything but the batch dimension. ``reshape`` (unlike
        # ``view``) also accepts non-contiguous backbone output.
        features = features.reshape(batch_size, -1)  # (batch_size, image_feature_dim)

        logits = self.classification_layer(features)  # (batch_size, num_classes)
        return logits

class DecoderRNN(nn.Module):
    """LSTM caption decoder driven by an image feature vector.

    The image feature (width ``embed_size``) is fed as the first LSTM input;
    subsequent inputs are word embeddings.
    """
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=100):
        """Set the hyper-parameters and build the layers.

        Args:
            embed_size: width of word embeddings (and of the image feature).
            hidden_size: LSTM hidden width.
            vocab_size: number of tokens in the vocabulary.
            num_layers: number of stacked LSTM layers.
            max_seq_length: number of tokens produced by the sampling methods.
        """
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) # change for LSTM or RNN
        self.linear = nn.Linear(hidden_size, vocab_size)
        # (sic) attribute name kept as-is for compatibility with existing code.
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, lengths, teacher_forcing_ratio=0.5):
        """Decode image feature vectors step by step with scheduled teacher forcing.

        Args:
            features: (batch_size, embed_size) image features.
            captions: (batch_size, seq_len) padded token ids.
            lengths: accepted for interface compatibility; not used here.
            teacher_forcing_ratio: probability, drawn once per time step for
                the whole batch, of feeding the ground-truth previous token
                instead of the model's own previous prediction.

        Returns:
            (batch_size, seq_len, vocab_size) logits; step ``t`` scores
            ``captions[:, t]``.
        """
        batch_size = features.size(0)
        seq_len = captions.size(1)

        # Position 0 holds the image feature, position t > 0 the embedding of
        # captions[:, t-1].
        embeddings = torch.cat((features.unsqueeze(1), self.embed(captions)), 1)

        outputs = torch.zeros(batch_size, seq_len, self.linear.out_features,
                              device=features.device)

        states = None      # nn.LSTM treats None as zero-initialised (h, c)
        predicted = None   # argmax token of the previous step
        for t in range(seq_len):
            # The first input is always the image feature; later steps flip a
            # coin between ground truth and the previous prediction.
            if t == 0 or torch.rand(1).item() < teacher_forcing_ratio:
                input_t = embeddings[:, t, :].unsqueeze(1)
            else:
                input_t = self.embed(predicted).unsqueeze(1)

            output, states = self.lstm(input_t, states)
            logits = self.linear(output.squeeze(1))
            outputs[:, t, :] = logits

            _, predicted = logits.max(1)

        return outputs

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search.

        Returns:
            (batch_size, max_seq_length) tensor of token ids. Generation does
            not stop at an end token; callers have to truncate.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))            # outputs:  (batch_size, vocab_size)
            _, predicted = outputs.max(1)                        # predicted: (batch_size)
            sampled_ids.append(predicted)

            inputs = self.embed(predicted).unsqueeze(1)          # inputs: (batch_size, 1, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)                # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids

    def stochastic_sample(self, features, temperature, states=None):
        """Generate captions by sampling from the temperature-scaled softmax.

        Args:
            features: (batch_size, embed_size) image features.
            temperature: divisor applied to the logits before the softmax.
            states: optional initial LSTM state.

        Returns:
            (batch_size, max_seq_length) tensor of sampled token ids.
        """
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)          # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))            # outputs:  (batch_size, vocab_size)

            soft_out = F.softmax(outputs/temperature, dim=1)
            # (batch_size, 1) -> (batch_size,); works for any batch size.
            predicted = torch.multinomial(soft_out, 1).squeeze(1)

            sampled_ids.append(predicted)

            inputs = self.embed(predicted).unsqueeze(1)          # inputs: (batch_size, 1, embed_size)
        sampled_ids = torch.stack(sampled_ids, 1)                # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids

    def beam_search_sample(self, features, states=None, beam_width=3):
        """Generate captions using beam search.

        Every batch element is decoded independently. Each beam keeps its own
        token history, cumulative log-probability, LSTM state and next input.

        Args:
            features: (batch_size, embed_size) image features.
            states: optional initial LSTM state ``(h, c)`` with the batch on
                dimension 1; it is sliced per batch element.
            beam_width: number of hypotheses kept per step (clamped to the
                vocabulary size).

        Returns:
            (batch_size, max_seq_length) tensor holding the best-scoring
            sequence of every batch element.
        """
        batch_size = features.size(0)
        best_sequences = []

        for batch_idx in range(batch_size):
            if states is None:
                init_states = None
            else:
                init_states = tuple(
                    s[:, batch_idx:batch_idx + 1].contiguous() for s in states
                )
            first_input = features[batch_idx].view(1, 1, -1)     # (1, 1, embed_size)

            # Each beam: (token ids, cumulative log-prob, lstm state, next lstm input)
            beams = [([], 0.0, init_states, first_input)]

            for _ in range(self.max_seg_length):
                candidates = []
                for seq, score, beam_states, beam_input in beams:
                    hiddens, new_states = self.lstm(beam_input, beam_states)
                    # Log-probabilities are summed instead of multiplying raw
                    # probabilities, which underflows on long sequences.
                    log_probs = F.log_softmax(self.linear(hiddens.squeeze(1)), dim=1)  # (1, vocab_size)
                    k = min(beam_width, log_probs.size(1))
                    top_log_probs, top_words = log_probs.topk(k, dim=1)
                    for j in range(k):
                        word = top_words[0, j]
                        candidates.append((
                            seq + [word.item()],
                            score + top_log_probs[0, j].item(),
                            new_states,   # never modified in place, safe to share
                            word,
                        ))

                # Keep the beam_width best hypotheses and give each one the
                # embedding of its own last token as the next input.
                candidates.sort(key=lambda cand: cand[1], reverse=True)
                beams = [
                    (seq, score, cand_states, self.embed(word.view(1)).unsqueeze(1))
                    for seq, score, cand_states, word in candidates[:beam_width]
                ]

            best_sequences.append(beams[0][0])

        # Convert to a (batch_size, max_seq_length) tensor.
        return torch.tensor(best_sequences, device=features.device)

In [5]:
# SECURITY: pickle.load can execute arbitrary code; only load vocabulary files
# from a trusted source.
with open(params['vocab_path'], 'rb') as f:
    vocab = pickle.load(f)

# Training pipeline. NOTE: images are transformed once while being pre-loaded
# below, so the random flip is drawn a single time per image (not per epoch);
# after the resize to image_size the RandomCrop of the same size is a no-op.
transform = transforms.Compose([
        transforms.RandomCrop(params['image_size']),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

# Deterministic pipeline for validation images: no random augmentation.
eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])


def _load_images(csv_name, image_transform):
    """Read ``csv_name`` and pre-load every image listed in its ``path`` column.

    Each file is located with a glob below ``params['data_path']``, converted
    to RGB, resized to ``image_size`` x ``image_size`` and passed through
    ``image_transform``.

    Returns:
        (frame, images): the DataFrame and a float tensor of shape
        (len(frame), 3, image_size, image_size), in CSV row order.

    Raises:
        FileNotFoundError: when no file matches a listed path.
    """
    frame = pd.read_csv(params['data_path'] + csv_name)
    size = params['image_size']
    images = torch.zeros(len(frame), 3, size, size)
    for i, rel_path in enumerate(tqdm(frame['path'])):
        pattern = params['data_path'] + '**/**/' + str(rel_path)
        matches = glob(pattern)
        if not matches:
            raise FileNotFoundError(f'no image matches {pattern!r}')
        with Image.open(matches[0]) as img:
            # Force 3 channels so grayscale / RGBA files fit the tensor and
            # the 3-channel Normalize.
            images[i] = image_transform(img.convert('RGB').resize((size, size)))
    return frame, images


df, train_list = _load_images(params['train_csv'], transform)
df, test_list = _load_images(params['val_csv'], eval_transform)

train_dataset=CustomDataset(train_list,params['data_path'],params['image_size'],params['train_csv'],'train',vocab,transform=transform)
test_dataset=CustomDataset(test_list,params['data_path'],params['image_size'],params['val_csv'],'val',vocab,transform=eval_transform)
train_dataloader=DataLoader(train_dataset,batch_size=params['batch_size'],shuffle=True,collate_fn=collate_fn)
# Validation order does not matter for the loss; keep it fixed.
val_dataloader=DataLoader(test_dataset,batch_size=params['batch_size'],shuffle=False,collate_fn=collate_fn)

100%|██████████| 7343/7343 [31:26<00:00,  3.89it/s]  
100%|██████████| 1836/1836 [08:17<00:00,  3.69it/s]


In [6]:

# Build the encoder (backbone + linear projection to embed_size) and the LSTM
# decoder, then the loss and a single optimizer covering both models.
Feature_Extractor=FeatureExtractor()
# 131072 is the flattened feature size the encoder's linear layers expect.
# NOTE(review): presumably 512 channels x 16 x 16 for maxvit_tiny_tf_512 at
# 512x512 input - confirm if image_size or the backbone changes.
encoder = AttentionMILModel(params['embed_size'],131072,Feature_Extractor).to(device)
decoder = DecoderRNN(params['embed_size'], params['hidden_size'], len(vocab), params['num_layers']).to(device)
# Token-level cross entropy; the training loop packs outputs/targets with
# pack_padded_sequence before calling it, so padded positions are excluded.
criterion = nn.CrossEntropyLoss()
# One Adam optimizer over decoder and encoder parameters (backbone included).
model_param = list(decoder.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(model_param, lr=params['lr'], betas=(params['beta1'], params['beta2']))
# summary(encoder, input_size=(params['batch_size'], 3, params['image_size'], params['image_size']))

In [7]:

plt_count=0  # not read in this cell; kept in case other cells rely on it
sum_loss = float('inf')   # best (lowest) summed validation loss seen so far
scheduler = 0.90          # per-epoch decay factor of the teacher-forcing ratio
teacher_forcing=0.5       # teacher-forcing ratio at epoch 0

# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs('../../model/captioning', exist_ok=True)

for epoch in range(params['epochs']):
    # Teacher forcing decays geometrically with the epoch index.
    tf_ratio = teacher_forcing*(scheduler**epoch)

    # ---- training ----
    # Enable training behaviour (dropout, normalisation statistics, ...).
    encoder.train()
    decoder.train()
    train=tqdm(train_dataloader)
    count=0
    train_loss = 0.0
    for images,captions,lengths in train:
        count+=1
        images = images.to(device)
        captions = captions.to(device)
        # Packing drops the padded positions from both targets and outputs.
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
        features = encoder(images)
        outputs = decoder(features, captions, lengths, teacher_forcing_ratio=tf_ratio)
        outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss+=loss.item()
        train.set_description(f"epoch: {epoch+1}/{params['epochs']} Step: {count} loss : {train_loss/count:.4f} ")

    # ---- validation ----
    # Switch to inference behaviour so validation neither applies dropout nor
    # updates running statistics.
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        val_count=0
        val_loss = 0.0
        val=tqdm(val_dataloader)
        for images,captions,lengths in val:
            val_count+=1
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
            features = encoder(images)
            # NOTE(review): validation uses the same stochastic teacher-forcing
            # ratio as training, so this loss is not deterministic.
            outputs = decoder(features, captions, lengths, teacher_forcing_ratio=tf_ratio)
            outputs = pack_padded_sequence(outputs, lengths, batch_first=True)[0]
            loss = criterion(outputs, targets)
            val_loss+=loss.item()
            val.set_description(f"epoch: {epoch+1}/{params['epochs']} Step: {val_count} loss : {val_loss/val_count:.4f} ")

    # Per-epoch checkpoints.
    torch.save(encoder.state_dict(), f'../../model/captioning/maxvit_BR_encoder_{epoch+1}.pth')
    torch.save(decoder.state_dict(), f'../../model/captioning/maxvit_BR_decoder_{epoch+1}.pth')
    # "Best so far" checkpoints, selected on the summed validation loss.
    if val_loss<sum_loss:
        sum_loss=val_loss
        torch.save(encoder.state_dict(), '../../model/captioning/maxvit_BR_encoder_check.pth')
        torch.save(decoder.state_dict(), '../../model/captioning/maxvit_BR_decoder_check.pth')
        

epoch: 1/10000 Step: 460 loss : 2.6152 : 100%|██████████| 459/459 [09:44<00:00,  1.27s/it]
epoch: 1/10000 Step: 116 loss : 1.0770 : 100%|██████████| 115/115 [01:21<00:00,  1.40it/s]
epoch: 2/10000 Step: 460 loss : 0.5905 : 100%|██████████| 459/459 [09:27<00:00,  1.24s/it]
epoch: 2/10000 Step: 116 loss : 0.3139 : 100%|██████████| 115/115 [01:05<00:00,  1.75it/s]
epoch: 3/10000 Step: 460 loss : 0.2653 : 100%|██████████| 459/459 [08:57<00:00,  1.17s/it]
epoch: 3/10000 Step: 116 loss : 0.2029 : 100%|██████████| 115/115 [01:11<00:00,  1.61it/s]
epoch: 4/10000 Step: 460 loss : 0.2313 : 100%|██████████| 459/459 [08:57<00:00,  1.17s/it]
epoch: 4/10000 Step: 116 loss : 0.2038 : 100%|██████████| 115/115 [01:09<00:00,  1.66it/s]
epoch: 5/10000 Step: 460 loss : 0.2684 : 100%|██████████| 459/459 [08:57<00:00,  1.17s/it]
epoch: 5/10000 Step: 116 loss : 0.2376 : 100%|██████████| 115/115 [01:09<00:00,  1.66it/s]
epoch: 6/10000 Step: 460 loss : 0.2790 : 100%|██████████| 459/459 [08:52<00:00,  1.16s/it]

KeyboardInterrupt: 