<a href="https://colab.research.google.com/github/CopotronicRifat/CSE-437-PATTERN-RECOGNITION/blob/master/CAPTMFN_50.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/openai/CLIP.git
!git clone https://github.com/jefferyYu/TomBERT.git
!git clone https://github.com/Porky-Pig/TwitterImageData.git

import os
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import torch
from torch import nn, optim
from transformers import RobertaModel, RobertaTokenizer, BertConfig
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import clip
from sklearn.metrics import f1_score

# Define paths to the data
tsv_base_path = 'TomBERT/absa_data/twitter2015'
image_base_path = 'TwitterImageData/twitter2015_images'

# Verify paths
print("TSV Directory exists:", os.path.exists(tsv_base_path))
print("Image Directory exists:", os.path.exists(image_base_path))

# Define the path to the base directory containing the TSV files
columns = ['index', 'Label', 'ImageID', 'String1', 'String2']

# Function to load and prepare data with the correct number of columns
def load_and_prepare_data(filename):
    file_path = os.path.join(tsv_base_path, filename)
    return pd.read_csv(file_path, sep='\t', header=0, names=columns)

# Load the data files
train_df = load_and_prepare_data('train.tsv')
dev_df = load_and_prepare_data('dev.tsv')
test_df = load_and_prepare_data('test.tsv')

# Combine train and dev sets
full_train_df = pd.concat([train_df, dev_df])

# Split into new train and validation sets
train_df, valid_df = train_test_split(full_train_df, test_size=0.1)

# Initialize OpenAI CLIP model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Custom Dataset class
class CustomDataset(Dataset):
    def __init__(self, dataframe, image_base_path, transform=None):
        self.dataframe = dataframe
        self.image_base_path = image_base_path
        self.transform = transform
        self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]

        tweet = row['String1']
        aspect_term = row['String2']

        inputs = self.tokenizer(tweet, return_tensors="pt", padding='max_length', max_length=50, truncation=True)
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        aspect_inputs = self.tokenizer(aspect_term, return_tensors="pt", padding='max_length', max_length=20, truncation=True)
        aspect_ids = aspect_inputs['input_ids'].squeeze(0)
        aspect_attention_mask = aspect_inputs['attention_mask'].squeeze(0)

        image_path = os.path.join(self.image_base_path, row['ImageID'])
        image = Image.open(image_path).convert('RGB')

        # Generate caption using CLIP
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            text_inputs = clip.tokenize(["This is a photo of"]).to(device)
            logits_per_image, logits_per_text = clip_model(image_input, text_inputs)
            caption = clip.tokenize("This is a photo of").to(device)

        caption_inputs = self.tokenizer.decode(caption.squeeze(0).cpu().numpy())
        caption_inputs = self.tokenizer(caption_inputs, return_tensors="pt", padding='max_length', max_length=50, truncation=True)
        caption_ids = caption_inputs['input_ids'].squeeze(0)
        caption_attention_mask = caption_inputs['attention_mask'].squeeze(0)

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(row['Label'], dtype=torch.long)

        return input_ids, attention_mask, aspect_ids, aspect_attention_mask, image, caption_ids, caption_attention_mask, label

# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

# Create DataLoader
train_dataset = CustomDataset(train_df, image_base_path, transform=transform)
valid_dataset = CustomDataset(valid_df, image_base_path, transform=transform)
test_dataset = CustomDataset(test_df, image_base_path, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

class TMFN(nn.Module):
    def __init__(self):
        super(TMFN, self).__init__()
        self.text_encoder = RobertaModel.from_pretrained('roberta-base')
        self.image_encoder = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.dropout = nn.Dropout(p=0.3)
        self.fusion_layer = nn.Linear(768 + 1000 + 768, 512)
        self.classifier = nn.Linear(512, 3)

        # Caption generation components
        decoder_config = BertConfig.from_pretrained('roberta-base')
        self.caption_decoder = RobertaModel(decoder_config)
        self.caption_linear = nn.Linear(768, self.text_encoder.config.vocab_size)

    def forward(self, input_ids, attention_mask, aspect_ids, aspect_attention_mask, images, caption_ids, caption_attention_mask):
        text_features = self.text_encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        aspect_features = self.text_encoder(aspect_ids, attention_mask=aspect_attention_mask).last_hidden_state[:, 0, :]
        image_features = self.image_encoder(images)
        combined_features = torch.cat([text_features, aspect_features, image_features], dim=1)
        fusion_output = torch.relu(self.fusion_layer(combined_features))
        fusion_output = self.dropout(fusion_output)
        logits = self.classifier(fusion_output)

        # Caption generation
        caption_outputs = self.caption_decoder(input_ids=caption_ids, attention_mask=caption_attention_mask)
        caption_logits = self.caption_linear(caption_outputs.last_hidden_state)

        return logits, caption_logits

def evaluate_model(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data in dataloader:
            input_ids, attention_mask, aspect_ids, aspect_attention_mask, images, caption_ids, caption_attention_mask, labels = [d.to(device) for d in data]
            outputs, _ = model(input_ids, attention_mask, aspect_ids, aspect_attention_mask, images, caption_ids, caption_attention_mask)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    f1 = f1_score(all_labels, all_preds, average='macro')
    return f1

model = TMFN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()
caption_criterion = nn.CrossEntropyLoss(ignore_index=model.text_encoder.config.pad_token_id)

best_valid_f1 = 0
classification_weight = 0.5
caption_weight = 0.5

for epoch in range(50):
    model.train()
    total_loss = 0
    for data in tqdm(train_loader):
        input_ids, attention_mask, aspect_ids, aspect_attention_mask, images, caption_ids, caption_attention_mask, labels = [d.to(device) for d in data]
        outputs, caption_logits = model(input_ids, attention_mask, aspect_ids, aspect_attention_mask, images, caption_ids, caption_attention_mask)

        loss = classification_weight * criterion(outputs, labels)
        caption_loss = caption_weight * caption_criterion(caption_logits.view(-1, caption_logits.size(-1)), caption_ids.view(-1))
        total_loss = loss + caption_loss

        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()

    valid_f1 = evaluate_model(model, valid_loader, device)
    print(f'Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}, Validation F1: {valid_f1}')

    if valid_f1 > best_valid_f1:
        best_valid_f1 = valid_f1
        torch.save(model.state_dict(), 'best_model.pth')

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate on the test set
test_f1 = evaluate_model(model, test_loader, device)
print(f'Test F1 Score: {test_f1}')


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-xrmncvsl
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-xrmncvsl
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25ldone
fatal: destination path 'TomBERT' already exists and is not an empty directory.
fatal: destination path 'TwitterImageData' already exists and is not an empty directory.
TSV Directory exists: True
Image Directory exists: True


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You are using a model of type roberta to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 1, Loss: 0.0008295896695926785, Validation F1: 0.6349194764754538


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 2, Loss: 0.00048632282414473593, Validation F1: 0.6982129390217625


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 3, Loss: 0.0004549459263216704, Validation F1: 0.7191830767522575


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 4, Loss: 0.00015855552919674665, Validation F1: 0.7589695574428399


100%|██████████| 484/484 [04:48<00:00,  1.68it/s]


Epoch 5, Loss: 0.00010455430310685188, Validation F1: 0.7228486873300232


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 6, Loss: 0.0003452126693446189, Validation F1: 0.7173499888419679


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 7, Loss: 9.623046025808435e-06, Validation F1: 0.7345761081094148


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 8, Loss: 3.6708031984744594e-05, Validation F1: 0.7265483934555489


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 9, Loss: 4.473803073778981e-06, Validation F1: 0.7533645870319217


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 10, Loss: 2.1982890757499263e-05, Validation F1: 0.7048462969040585


100%|██████████| 484/484 [04:47<00:00,  1.68it/s]


Epoch 11, Loss: 7.256073786265915e-06, Validation F1: 0.728678835310906


100%|██████████| 484/484 [04:44<00:00,  1.70it/s]


Epoch 12, Loss: 2.6760742457554443e-06, Validation F1: 0.7297536150533306


100%|██████████| 484/484 [04:44<00:00,  1.70it/s]


Epoch 13, Loss: 0.00012894318206235766, Validation F1: 0.7079488639673462


100%|██████████| 484/484 [04:44<00:00,  1.70it/s]


Epoch 14, Loss: 4.162201094004558e-06, Validation F1: 0.7429977576087378


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 15, Loss: 3.6966216612199787e-06, Validation F1: 0.7302840038948427


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 16, Loss: 1.1011957212758716e-06, Validation F1: 0.7002001918736518


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 17, Loss: 3.215950982848881e-06, Validation F1: 0.7366499136583063


100%|██████████| 484/484 [04:47<00:00,  1.68it/s]


Epoch 18, Loss: 4.1866744140861556e-05, Validation F1: 0.7291937311976014


100%|██████████| 484/484 [04:47<00:00,  1.68it/s]


Epoch 19, Loss: 6.71169516408554e-07, Validation F1: 0.7321018761543696


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 20, Loss: 2.2682881990476744e-06, Validation F1: 0.7161946259985476


100%|██████████| 484/484 [04:44<00:00,  1.70it/s]


Epoch 21, Loss: 2.7094951292383485e-05, Validation F1: 0.7207270243424766


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 22, Loss: 6.000641405989882e-07, Validation F1: 0.7181553477690289


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 23, Loss: 6.886801173777712e-08, Validation F1: 0.7333816370142762


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 24, Loss: 1.6114514664877788e-06, Validation F1: 0.7182230467944754


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 25, Loss: 3.0294188491097884e-07, Validation F1: 0.734823320928584


100%|██████████| 484/484 [04:45<00:00,  1.70it/s]


Epoch 26, Loss: 9.634550951886922e-05, Validation F1: 0.7159037893372572


100%|██████████| 484/484 [04:45<00:00,  1.69it/s]


Epoch 27, Loss: 3.764264704386733e-07, Validation F1: 0.6995168626464608


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 28, Loss: 1.7827267129177926e-06, Validation F1: 0.7142484653435842


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 29, Loss: 6.483476511220942e-08, Validation F1: 0.732127317640686


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 30, Loss: 7.4396443778823595e-06, Validation F1: 0.7178080387856897


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 31, Loss: 3.8790490179962944e-06, Validation F1: 0.7202515836178199


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 32, Loss: 3.7069586511506714e-08, Validation F1: 0.7359721201230635


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 33, Loss: 1.666096068220213e-06, Validation F1: 0.7073173813545037


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 34, Loss: 1.6040936543504358e-06, Validation F1: 0.6940103666115659


100%|██████████| 484/484 [04:47<00:00,  1.68it/s]


Epoch 35, Loss: 0.0002375784097239375, Validation F1: 0.7220474960009845


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 36, Loss: 2.446004430112225e-07, Validation F1: 0.7497917082272316


100%|██████████| 484/484 [04:48<00:00,  1.68it/s]


Epoch 37, Loss: 8.825801387501997e-07, Validation F1: 0.7505729014620867


100%|██████████| 484/484 [04:51<00:00,  1.66it/s]


Epoch 38, Loss: 1.2616567346412921e-07, Validation F1: 0.7396985852431396


100%|██████████| 484/484 [04:51<00:00,  1.66it/s]


Epoch 39, Loss: 4.313857516535791e-06, Validation F1: 0.735254170753804


100%|██████████| 484/484 [04:50<00:00,  1.66it/s]


Epoch 40, Loss: 1.3557186093748896e-06, Validation F1: 0.7375979929120179


100%|██████████| 484/484 [04:50<00:00,  1.67it/s]


Epoch 41, Loss: 3.138759439025307e-06, Validation F1: 0.7396123633205699


100%|██████████| 484/484 [04:50<00:00,  1.67it/s]


Epoch 42, Loss: 0.0013165527489036322, Validation F1: 0.7457433027371835


100%|██████████| 484/484 [04:50<00:00,  1.66it/s]


Epoch 43, Loss: 8.808227721601725e-05, Validation F1: 0.7334069677554603


100%|██████████| 484/484 [04:51<00:00,  1.66it/s]


Epoch 44, Loss: 6.983396474424808e-07, Validation F1: 0.7277967834532656


100%|██████████| 484/484 [04:51<00:00,  1.66it/s]


Epoch 45, Loss: 1.9094193248747615e-06, Validation F1: 0.7236427967891846


100%|██████████| 484/484 [04:47<00:00,  1.68it/s]


Epoch 46, Loss: 1.0706709190344554e-06, Validation F1: 0.7466362797123685


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 47, Loss: 2.4775399509735507e-08, Validation F1: 0.7424501921379446


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 48, Loss: 0.0007479292689822614, Validation F1: 0.7279840353622796


100%|██████████| 484/484 [04:47<00:00,  1.68it/s]


Epoch 49, Loss: 0.00010704964370233938, Validation F1: 0.6852249059351863


100%|██████████| 484/484 [04:46<00:00,  1.69it/s]


Epoch 50, Loss: 9.002966407933854e-07, Validation F1: 0.7492087112622827
Test F1 Score: 0.7407656636499418
