In [1]:
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import os
import copy
import cv2
import json
import transformers
from transformers import AutoImageProcessor, ViTModel
from PIL import Image, ImageMath
from transformers import AutoModel
import torch.nn as nn
from transformers import AutoTokenizer, BertModel, BertTokenizer
from torch.optim import Adam
from tqdm import tqdm
import torch.nn.functional as F
from ast import literal_eval
from torchvision.models import resnet50,SwinTransformer
import torchvision
from transformers import AutoFeatureExtractor, ResNetForImageClassification
from transformers import AutoImageProcessor, AutoModelForImageClassification
import requests
from transformers import RobertaTokenizer, RobertaModel
from focal_loss.focal_loss import FocalLoss
# from focal_loss_pytorch.focal_loss_pytorch.focal_loss import BinaryFocalLoss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
data_dir = '/home/server-ailab-12gb/DUC-MMM/data/Musti'
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
train_img_dir = os.path.join(train_dir, 'img')
train_file_path = os.path.join(train_dir, 'aug_train.json')

test_img_dir = os.path.join(test_dir, 'img')
test_file_path = os.path.join(test_dir, 'test.json')
checkpoint_directory = "/home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_FOCAL_2_VIT_BERT"
CHECKPOINT_EXTENSION = '.pt'

save_path = '/home/server-ailab-12gb/DUC-MMM/save/T_W_FOCAL_2_VIT_BERT'

In [4]:
def save_checkpoint(checkpoint_directory, epoch, model, optimizer, LOSS, checkpoint_name=None):
    """
    The checkpoint will be saved in `checkpoint_directory` with name `checkpoint_name`.
    If `checkpoint_name` is None, the checkpoint will be saved with name `next_checkpoint_name_id + epoch`.
    """
    if checkpoint_directory is not None:
        if checkpoint_name is None:
            checkpoint_name = f'{epoch}{CHECKPOINT_EXTENSION}'

        path = os.path.join(checkpoint_directory, checkpoint_name)
        print("path: "+str(path))
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': LOSS,
                    }, path)
def load_checkpoint(PATH, model, optimizer):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    return epoch, model, optimizer, loss
def read_json(parient_dir, name=None):
    # Opening JSON file
    if name == None:
        path = parient_dir
    else:
        path = os.path.join(parient_dir, str(name) + ".json")
    with open(path, 'r') as openfile:

        # Reading from json file
        json_object = json.load(openfile)
    return json_object
def write_json(dict, parient_dir, name=None):
    # Serializing json
    json_object = json.dumps(dict, indent=4)
    if name != None:
        path = os.path.join(parient_dir, str(name)+".json")
    else:
        path = parient_dir
    # Writing to sample.json
    with open(path, "w") as outfile:
        outfile.write(json_object)
# img_dir = '/content/drive/MyDrive/HaruLab/Dataset/Musti/test/img'
# json_file =
# values = json_file['pairs']

def load_image_from_url(url):
    try:
        response = requests.get(url)
        response.raise_for_status()

        # Open the image using Pillow
        image = Image.open(BytesIO(response.content))
        return image
    except requests.exceptions.RequestException as e:
        print(f"Error loading image from URL: {e}")
        return None

## . Dataset&DataLoader

In [5]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
# tokenizer = AutoModel.from_pretrained("google/mt5-base")
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

In [6]:
class Dataset_MMM(Dataset):
    def __init__(self, file_path,image_dir ,image_processor,tokenizer, test=False):
        super(Dataset_MMM, self).__init__()
        self.image_dir = image_dir
        data = read_json(file_path)
        # print(data)
        print()
        self.data = []
        if test==False:
          for obj in data:
            if os.path.exists(os.path.join(self.image_dir, obj['image'])):
              self.data.append(obj)
        else:
          for obj in data:
            _, base_name, file_extension = split_filename(obj['image'])
            if os.path.exists(os.path.join(self.image_dir, base_name + file_extension)):
              obj['image'] = base_name + file_extension
              self.data.append(obj)
            else:
              print(obj['image'])
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir,self.data[index]['image'])
        text = self.data[index]['text']
        target = self.data[index]['subtask1_label']
        if target == 'NO':
          target = 0
        else:
          target = 1
        # process_image
        img = Image.open(img_path)
        if img.mode == 'L':
          img = img.convert('RGB')
        img = self.image_processor(images = img, return_tensors="pt",padding=True)
        # process_text
        text = self.tokenizer(text,padding="max_length", max_length=197,truncation=True, return_tensors="pt")

        return img, text, target

## Model

### Image

In [8]:
vit_encoder = ViTModel.from_pretrained("google/vit-large-patch32-384")
resnet_encoder = ResNetForImageClassification.from_pretrained("microsoft/resnet-152")
resnet_encoder = nn.Sequential(*list(resnet_encoder.children())[:-1])

In [9]:
swintransformer = AutoModelForImageClassification.from_pretrained("microsoft/swinv2-large-patch4-window12-192-22k")
swintransformer = nn.Sequential(*list(swintransformer.children())[:-1])

In [10]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
early_stopper = EarlyStopper(patience=5, min_delta=0)

### Text

In [11]:
# text_encoder = BertModel.from_pretrained("google/mt5-base").to(device)
text_encoder = RobertaModel.from_pretrained('bert-base-multilingual-cased')
encoder_layer = text_encoder.encoder
text_embedding= text_encoder.embeddings
pooling_layer= text_encoder.pooler

You are using a model of type bert to instantiate a model of type roberta. This is not supported for all configurations of models and can yield errors.
Some weights of RobertaModel were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['encoder.layer.6.intermediate.dense.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.3.output.dense.bias', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.7.attention.self.key.weig

### CrossModel

In [12]:
class CrossModel(nn.Module):
  def __init__(self,image_encoder,text_encoder,text_embedding, pooling_layer):
    super(CrossModel,self).__init__()
    #image_encoder
    self.vit_encoder = image_encoder[0]
    self.resnet_encoder = image_encoder[1]
    self.swintransformer_encoder = image_encoder[2]

    #text_encoder & text embedding
    self.text_encoder = text_encoder
    self.text_embedding = text_embedding
    self.pooling_layer = pooling_layer

    # layernorm
    self.layernorm_197 = nn.LayerNorm(197)
    self.layernorm_768 = nn.LayerNorm(768)

    # self.layernorm_1024 = nn.LayerNorm(1024)
    #linear
    self.linear = nn.Linear(2,1)

    # projection
    self.vit_projection_1 = nn.Linear(1024,768)
    self.vit_projection_2 = nn.Linear(145,197)
    # self.resnet_projection_2 = nn.Linear(2048,197)
    self.resnet_projection_2 = nn.Linear(2048,197)
    self.resnet_projection_3 = nn.Linear(49,768)

    self.swintransformer_projection_0 = nn.Linear(1536,768)
    self.swintransformer_projection_1 = nn.Linear(49,197)
    self.final_projection = nn.Linear(197,1)
    self.voting = nn.Linear(3,1)

  def forward(self,img,text):
    batch_size = img.size(0)
    #text_features
    text_embedding = self.text_embedding(text) # torch.Size([16, 197, 1024])

    #image_features
    image_features_0 = self.vit_encoder(img).last_hidden_state #torch.Size([16, 197, 768])
    image_features_1 = self.resnet_encoder(img).last_hidden_state #torch.Size([16, 512, 7, 7])
    image_features_2 = self.swintransformer_encoder(img).last_hidden_state #torch.Size([16, 49, 1536])

    dim = image_features_1.size(1)

    #projection_0
    image_features_0 = self.vit_projection_1(image_features_0) #torch.Size([16, 145, 768])
    image_features_0 = self.layernorm_768(image_features_0)
    image_features_0 = image_features_0.permute(0,2,1) #torch.Size([16, 768, 145])
    image_features_0 = self.vit_projection_2(image_features_0) #torch.Size([16, 768, 197])
    image_features_0 = self.layernorm_197(image_features_0)
    image_features_0 = image_features_0.permute(0,2,1) #torch.Size([16, 197, 768])

    #projection_1
    image_features_1 = image_features_1.reshape(batch_size, dim,-1) #torch.Size([16, 512, 49])
    image_features_1 = image_features_1.permute(0,2,1) #torch.Size([16, 49, 512])
    image_features_1 = self.resnet_projection_2(image_features_1) #torch.Size([16, 49, 197])
    image_features_1 = self.layernorm_197(image_features_1)
    image_features_1 = image_features_1.permute(0,2,1) #torch.Size([16, 197, 49])
    image_features_1 = self.resnet_projection_3(image_features_1) #torch.Size([16, 197, 768])
    image_features_1 = self.layernorm_768(image_features_1)

    #projection_2
    image_features_2 = self.swintransformer_projection_0(image_features_2) #torch.Size([16, 197, 768])
    image_features_2 = self.layernorm_768(image_features_2)
    image_features_2 = image_features_2.permute(0,2,1) #torch.Size([16, 768, 49])
    image_features_2 = self.swintransformer_projection_1(image_features_2) #torch.Size([16, 768, 197])
    image_features_2 = self.layernorm_197(image_features_2)
    image_features_2 = image_features_2.permute(0,2,1) #torch.Size([16, 197, 768])


    final_features_0 = self.layernorm_768(self.linear(torch.cat([image_features_0.unsqueeze(-1) , text_embedding.unsqueeze(-1)], dim =-1)).squeeze(-1) + text_embedding)
    final_features_1 = self.layernorm_768(self.linear(torch.cat([image_features_1.unsqueeze(-1) , text_embedding.unsqueeze(-1)], dim =-1)).squeeze(-1) + text_embedding)
    final_features_2 = self.layernorm_768(self.linear(torch.cat([image_features_2.unsqueeze(-1) , text_embedding.unsqueeze(-1)], dim =-1)).squeeze(-1) + text_embedding)


    final_features_0 = self.pooling_layer(self.text_encoder(final_features_0).last_hidden_state)
    final_features_1 = self.pooling_layer(self.text_encoder(final_features_1).last_hidden_state)
    final_features_2 = self.pooling_layer(self.text_encoder(final_features_2).last_hidden_state)


    image_features_0 = self.final_projection(image_features_0.permute(0,2,1)).squeeze(-1)
    image_features_1 = self.final_projection(image_features_1.permute(0,2,1)).squeeze(-1)
    image_features_2 = self.final_projection(image_features_2.permute(0,2,1)).squeeze(-1)

    logits_0 = F.cosine_similarity(final_features_0,image_features_0, dim =-1)
    logits_1 = F.cosine_similarity(final_features_1,image_features_1, dim =-1)
    logits_2 = F.cosine_similarity(final_features_2,image_features_2, dim =-1)
    logits = self.voting(torch.cat([logits_0.unsqueeze(-1), logits_1.unsqueeze(-1), logits_2.unsqueeze(-1)], dim =-1)).squeeze(-1)
    return logits

## Train

In [13]:
def train(train_loader, model,device, epochs=10, total_iterations_limit=None, optimizer=None, cur_epoch=0, best_valid_loss=10000):
    # model = model.to(device)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    history = {"loss": [],
                "acc": []}
    total_iterations = 0
    # best_valid_loss = 100000
    loss_fn = FocalLoss(gamma=2)
    # loss_fn = BinaryFocalLoss(gamma=5)
    for epoch in range(cur_epoch + 1, epochs):
        model.train()
        avg_acc = 0
        acc_sum = 0
        avg_loss = 0
        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            img = torch.squeeze(data[0]['pixel_values'],1)
            text = torch.squeeze(data[1]['input_ids'],1)
            target = torch.squeeze(data[2])
            # print(text.shape)
            img = img.to(device)
            text = text.to(device)
            target = target.to(device).to(torch.int64)
            optimizer.zero_grad()
            
            logits = model(img,text)
            # print("logits: " + str(torch.sigmoid(logits)))
            # print("target: " + str(target))
            loss = loss_fn(torch.sigmoid(logits), target)
            # print(loss)
            loss_sum += loss.item()

            pred = torch.round(logits)
            correct_pred = (target == pred).float()
            acc_sum += (correct_pred.sum() / len(correct_pred)).cpu().item()
            
            data_iterator.set_postfix(loss=loss_sum / num_iterations)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

        avg_loss = loss_sum / num_iterations
        avg_acc = acc_sum / num_iterations

        if avg_loss < best_valid_loss:
                best_valid_loss = avg_loss
                best_weights = copy.deepcopy(model.state_dict())
                save_checkpoint(checkpoint_directory, epoch, model, optimizer, avg_loss)
        history['loss'].append(avg_loss)
        history['acc'].append(avg_acc)
        print("Accuracy: "+str(avg_acc))
        if early_stopper.early_stop(avg_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(avg_loss))
                break
    return history

In [None]:
dataset_train = Dataset_MMM(train_file_path, train_img_dir,image_processor,tokenizer)
dataloader_train = DataLoader(dataset_train,batch_size=20,drop_last=True)

In [14]:
img_encoder =[vit_encoder,resnet_encoder,swintransformer]
model = CrossModel(img_encoder,encoder_layer,text_embedding, pooling_layer)

lr = 0.001
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
epoch = 0
loss = 10000
# index_checkpoint = '7'
# PATH = os.path.join(checkpoint_directory, index_checkpoint + CHECKPOINT_EXTENSION)
# epoch, model, optimizer, loss = load_checkpoint(PATH, model, optimizer)
# print("Load checkpoint: "+str(PATH))

history_losses = train(dataloader_train,model,device ,epochs=100, optimizer=optimizer, cur_epoch=epoch, best_valid_loss=loss)
write_json(history_losses, parient_dir=save_path, name='history_losses')

Epoch 1:   8%|▊         | 159/1891 [02:28<26:49,  1.08it/s, loss=0.14] 

## Predict

In [None]:
# def predict(model,device,test_file_path,image_processor):
#   model.eval()
#   json_dict = read_json(test_file_path,name='test')
#   json_values = json_dict['pairs']
#   preds = []
#   labels = []
#   for item in tqdm(json_values):
#     image_url = item['image']
#     text = item['text']
#     subtask1_label = item['subtask1_label']
#     image = load_image_from_url(image_url)
#     if image != None:
#       if image.mode == 'L':
#         image = image.convert('RGB')
#       image = image_processor(images = image, return_tensors="pt",padding=True).to(device)
#       # process_text
#       text = tokenizer(text,padding="max_length", max_length=197,truncation=True, return_tensors="pt").to(device)
#       logits = model(image['pixel_values'], text['input_ids'])
#       pred = torch.sigmoid(logits)
#       print("pred: "+str(pred))
#       if (pred) > 0.5:
#         pred = 'YES'
#         print('YES')
#       else:
#         pred = 'NO'
#       item['subtask1_label'] = pred
#   return json_dict# def predict(model,device,test_file_path,image_processor):
#   model.eval()
#   json_dict = read_json(test_file_path)
#   json_values = json_dict['pairs']
#   preds = []
#   labels = []
#   for item in tqdm(json_values):
#     image_url = item['image']
#     text = item['text']
#     subtask1_label = item['subtask1_label']
#     image = load_image_from_url(image_url)
#     if image != None:
#       if image.mode == 'L':
#         image = image.convert('RGB')
#       image = image_processor(images = image, return_tensors="pt",padding=True).to(device)
#       # process_text
#       text = tokenizer(text,padding="max_length", max_length=197,truncation=True, return_tensors="pt").to(device)
#       logits = model(image['pixel_values'], text['input_ids'])
#       pred = torch.sigmoid(logits)
#       print("pred: "+str(pred))
#       if (pred) > 0.5:
#         pred = 'YES'
#         print('YES')
#       else:
#         pred = 'NO'
#       item['subtask1_label'] = pred
#   return json_dict
def predict(model,device,dataloader,image_processor):
  list = []
  model.eval()

  data_iterator = tqdm(dataloader, desc=f'Epoch')
  with torch.no_grad():
    for data in data_iterator:
      img = torch.squeeze(data[0]['pixel_values'],1)
      text = torch.squeeze(data[1]['input_ids'],1)
      target = torch.squeeze(data[2])
      # print(text.shape)
      img = img.to(device)
      text = text.to(device)
      target = target.to(device).to(torch.int64)
      optimizer.zero_grad()
      
      logits = model(img,text)

      print("logits"+str(torch.sigmoid(logits)))
      pred = torch.round(torch.sigmoid(logits)).detach().cpu().numpy()
      list.extend(pred)
  return list

In [None]:
json_file = read_json('/content/drive/MyDrive/MMM CONFERENCE/data',name='test')
json_dict = predict(model,device,test_file_path,image_processor)
with open(save_path, 'w') as f:
    json.dump(json_dict, f, indent=4)