In [1]:
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import os
import copy
import cv2
import json
import transformers
from transformers import AutoImageProcessor, ViTModel
from PIL import Image, ImageMath
from transformers import AutoModel
import torch.nn as nn
from transformers import AutoTokenizer, BertModel, BertTokenizer
from torch.optim import Adam
from tqdm import tqdm
import torch.nn.functional as F
from ast import literal_eval
from torchvision.models import resnet50,SwinTransformer
import torchvision
from transformers import AutoFeatureExtractor, ResNetForImageClassification
from transformers import AutoImageProcessor, AutoModelForImageClassification
import requests
from io import BytesIO
from transformers import RobertaTokenizer, RobertaModel
from focal_loss.focal_loss import FocalLoss
from sklearn.metrics import classification_report
# from focal_loss_pytorch.focal_loss_pytorch.focal_loss import BinaryFocalLoss
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

  from .autonotebook import tqdm as notebook_tqdm


'cuda'

In [2]:
version = "T_W_nondropout_FOCAL_2_03_train_VIT_BERT"

In [3]:
FOLDER = "/home/server-ailab-12gb/DUC-MMM"
data_dir = '/home/server-ailab-12gb/DUC-MMM/data/Musti'
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')
train_img_dir = os.path.join(train_dir, 'img')
train_file_path = os.path.join(train_dir, 'train.json')

test_img_dir = os.path.join(test_dir, 'img')
test_file_path = os.path.join(test_dir, 'test.json')
history_save_folder = os.path.join(FOLDER, 'history_save')
checkpoint_dir = os.path.join(FOLDER, 'checkpoint')
checkpoint_directory = os.path.join(checkpoint_dir, version)
CHECKPOINT_EXTENSION = '.pt'

predict_dir = os.path.join(FOLDER, "predict")
predict_path = os.path.join(predict_dir, version + '.json')
report_dir = os.path.join(FOLDER, "report")
report_path = os.path.join(report_dir, version)

report_dir = os.path.join(FOLDER, "report")
report_path = os.path.join(report_dir, version)

train_index_checkpoint = ''
test_index_checkpoint = ''
lr = 0.001
epoch = 0
loss = 10000

In [4]:
def save_checkpoint(checkpoint_directory, epoch, model, optimizer, LOSS, checkpoint_name=None):
    """
    The checkpoint will be saved in `checkpoint_directory` with name `checkpoint_name`.
    If `checkpoint_name` is None, the checkpoint will be saved with name `next_checkpoint_name_id + epoch`.
    """
    if checkpoint_directory is not None:
        if checkpoint_name is None:
            checkpoint_name = f'{epoch}{CHECKPOINT_EXTENSION}'

        path = os.path.join(checkpoint_directory, checkpoint_name)
        print("path: "+str(path))
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': LOSS,
                    }, path)
def load_checkpoint(PATH, model, optimizer):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    return epoch, model, optimizer, loss

def save_report(pred_list, target_list, report_path):
    report = classification_report(target_list, pred_list, target_names=['NO', 'YES'], digits=6)
    with open(report_path, 'w') as f:
        f.write(report)
        
def read_json(parient_dir, name=None):
    # Opening JSON file
    if name == None:
        path = parient_dir
    else:
        path = os.path.join(parient_dir, str(name) + ".json")
    with open(path, 'r') as openfile:

        # Reading from json file
        json_object = json.load(openfile)
    return json_object
def write_json(dict, parient_dir, name=None):
    # Serializing json
    json_object = json.dumps(dict, indent=4)
    if name != None:
        path = os.path.join(parient_dir, str(name)+".json")
    else:
        path = parient_dir
    # Writing to sample.json
    with open(path, "w") as outfile:
        outfile.write(json_object)
# img_dir = '/content/drive/MyDrive/HaruLab/Dataset/Musti/test/img'
# json_file =
# values = json_file['pairs']

def load_image_from_url(url):
    try:
        response = requests.get(url)
        response.raise_for_status()

        # Open the image using Pillow
        image = Image.open(BytesIO(response.content))
        return image
    except requests.exceptions.RequestException as e:
        print(f"Error loading image from URL: {e}")
        return None
def split_filename(filepath):
    # Split the filepath into the directory and file parts
    directory, full_filename = os.path.split(filepath)

    # Split the file part into the base name and extension
    base_name, file_extension = os.path.splitext(full_filename)

    return directory, base_name, file_extension

## . Dataset&DataLoader

In [5]:
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
# tokenizer = AutoModel.from_pretrained("google/mt5-base")
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
vit_encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
resnet_encoder = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
resnet_encoder = nn.Sequential(*list(resnet_encoder.children())[:-1])
swintransformer = AutoModelForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
swintransformer = nn.Sequential(*list(swintransformer.children())[:-1])
# text_encoder = BertModel.from_pretrained("google/mt5-base").to(device)
text_encoder = BertModel.from_pretrained('bert-base-multilingual-cased')
encoder_layer = text_encoder.encoder
text_embedding= text_encoder.embeddings
pooling_layer= text_encoder.pooler

In [6]:
class Dataset_MMM(Dataset):
    def __init__(self, file_path,image_dir ,image_processor,tokenizer):
        super(Dataset_MMM, self).__init__()
        self.image_dir = image_dir
        data = read_json(file_path)
        if isinstance(data, dict):
            data = data['pairs']
        print(self.image_dir)
        self.data = []
        for obj in data:
          if os.path.exists(os.path.join(self.image_dir, obj['image'].rstrip())):
            self.data.append(obj)
          else:
            _, base_name, file_extension = split_filename(obj['image'].rstrip())
            if os.path.exists(os.path.join(self.image_dir, base_name + file_extension)):
              obj['image'] = base_name + file_extension
              self.data.append(obj)
            else:
              img = load_image_from_url(obj['image'].rstrip())
              obj['image'] = base_name + file_extension
              img.save(os.path.join(self.image_dir, base_name + file_extension))
              
              self.data.append(obj)
              print(base_name + file_extension)
        print("Len: " +str(len(self.data)))
        self.image_processor = image_processor
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir,self.data[index]['image'])
        text = self.data[index]['text']
        target = self.data[index]['subtask1_label']
        if target == 'NO':
          target = 0
        else:
          target = 1
        # process_image
        img = Image.open(img_path)
        if img.mode == 'L':
          img = img.convert('RGB')
        img = self.image_processor(images = img, return_tensors="pt",padding=True)
        # process_text
        text = self.tokenizer(text,padding="max_length", max_length=197,truncation=True, return_tensors="pt")

        return img, text, target

## Model

### Image

In [7]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
early_stopper = EarlyStopper(patience=5, min_delta=0)

### Text

### CrossModel

In [8]:
class CrossModel(nn.Module):
  def __init__(self,image_encoder,text_encoder,text_embedding, pooling_layer):
    super(CrossModel,self).__init__()
    #image_encoder
    self.vit_encoder = image_encoder[0]
    self.resnet_encoder = image_encoder[1]
    self.swintransformer_encoder = image_encoder[2]

    #text_encoder & text embedding
    self.text_encoder = text_encoder
    self.text_embedding = text_embedding
    self.pooling_layer = pooling_layer

    # layernorm
    self.layernorm_197 = nn.LayerNorm(197)
    self.layernorm_768 = nn.LayerNorm(768)

    # self.layernorm_1024 = nn.LayerNorm(1024)
    #linear
    self.linear = nn.Linear(2,1)

    # projection
    # self.resnet_projection_2 = nn.Linear(2048,197)
    self.resnet_projection_2 = nn.Linear(2048,197)
    self.resnet_projection_3 = nn.Linear(49,768)
    
    # Dropout
    self.dropout = nn.Dropout(p=0.2)

    self.swintransformer_projection_1 = nn.Linear(49,197)
    self.final_projection = nn.Linear(197,1)
    self.voting = nn.Linear(3,1)

  def forward(self,img,text):
    batch_size = img.size(0)
    #text_features
    text_embedding = self.text_embedding(text) # torch.Size([16, 197, 1024])

    #image_features
    image_features_0 = self.vit_encoder(img).last_hidden_state #torch.Size([16, 197, 768])
    image_features_1 = self.resnet_encoder(img).last_hidden_state #torch.Size([16, 512, 7, 7])
    image_features_2 = self.swintransformer_encoder(img).last_hidden_state #torch.Size([16, 49, 768])
    dim = image_features_1.size(1)
    #projection_1
    image_features_1 = image_features_1.reshape(batch_size, dim,-1) #torch.Size([16, 512, 49])
    image_features_1 = image_features_1.permute(0,2,1) #torch.Size([16, 49, 512])
    image_features_1 = self.resnet_projection_2(image_features_1) #torch.Size([16, 49, 197])
    image_features_1 = self.layernorm_197(image_features_1)
    image_features_1 = image_features_1.permute(0,2,1) #torch.Size([16, 197, 49])
    image_features_1 = self.resnet_projection_3(image_features_1) #torch.Size([16, 197, 768])
    image_features_1 = self.layernorm_768(image_features_1)
    # image_features_1 = self.dropout(image_features_1)
    #projection_2
    image_features_2 = image_features_2.permute(0,2,1) #torch.Size([16, 768, 49])
    image_features_2 = self.swintransformer_projection_1(image_features_2) #torch.Size([16, 768, 197])
    image_features_2 = self.layernorm_197(image_features_2)
    image_features_2 = image_features_2.permute(0,2,1) #torch.Size([16, 197, 768])

    final_features_0 = self.layernorm_768(self.linear(torch.cat([image_features_0.unsqueeze(-1) , text_embedding.unsqueeze(-1)], dim =-1)).squeeze(-1) + text_embedding)
    final_features_1 = self.layernorm_768(self.linear(torch.cat([image_features_1.unsqueeze(-1) , text_embedding.unsqueeze(-1)], dim =-1)).squeeze(-1) + text_embedding)
    final_features_2 = self.layernorm_768(self.linear(torch.cat([image_features_2.unsqueeze(-1) , text_embedding.unsqueeze(-1)], dim =-1)).squeeze(-1) + text_embedding)

    final_features_0 = self.pooling_layer(self.text_encoder(final_features_0).last_hidden_state)
    final_features_1 = self.pooling_layer(self.text_encoder(final_features_1).last_hidden_state)
    final_features_2 = self.pooling_layer(self.text_encoder(final_features_2).last_hidden_state)


    image_features_0 = self.final_projection(image_features_0.permute(0,2,1)).squeeze(-1)
    image_features_1 = self.final_projection(image_features_1.permute(0,2,1)).squeeze(-1)
    image_features_2 = self.final_projection(image_features_2.permute(0,2,1)).squeeze(-1)

    logits_0 = F.cosine_similarity(final_features_0,image_features_0, dim =-1)
    logits_1 = F.cosine_similarity(final_features_1,image_features_1, dim =-1)
    logits_2 = F.cosine_similarity(final_features_2,image_features_2, dim =-1)
    logits = self.voting(torch.cat([logits_0.unsqueeze(-1), logits_1.unsqueeze(-1), logits_2.unsqueeze(-1)], dim =-1)).squeeze(-1)
    return logits

## Train

In [9]:
# dataset_train = Dataset_MMM(train_file_path, train_img_dir,image_processor,tokenizer,is_url=is_url)
# dataloader_train = DataLoader(dataset_train,batch_size=20,drop_last=True, shuffle=True)
# len(dataset_train)

In [10]:
def train(train_loader, model,device, epochs=10, total_iterations_limit=None, optimizer=None, cur_epoch=0, best_valid_loss=10000):
    # model = model.to(device)
    # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    history = {"loss": [],
                "acc": []}
    

    total_iterations = 0
    loss_fn = FocalLoss(gamma=2, weights=torch.tensor(0.3))
    for epoch in range(cur_epoch + 1, epochs):
        model.train()
        avg_acc = 0
        acc_sum = 0
        avg_loss = 0
        loss_sum = 0
        num_iterations = 0
        pred_list = []
        target_list = []

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            img = torch.squeeze(data[0]['pixel_values'],1)
            text = torch.squeeze(data[1]['input_ids'],1)
            target = torch.squeeze(data[2])
            # print(text.shape)
            img = img.to(device)
            text = text.to(device)
            target = target.to(device).to(torch.int64)
            optimizer.zero_grad()
            
            logits = model(img,text)
            loss = loss_fn(torch.sigmoid(logits), target)
            loss_sum += loss.item()

            pred = torch.round(torch.sigmoid(logits.detach()))
            correct_pred = (target == pred).float()
            acc_sum += (correct_pred.sum() / len(correct_pred)).cpu().item()
            

            pred_list.extend(pred.cpu().numpy())
            target_list.extend(target.cpu().numpy())


            data_iterator.set_postfix(loss=loss_sum / num_iterations)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

        avg_loss = loss_sum / num_iterations
        avg_acc = acc_sum / num_iterations

        if avg_loss < best_valid_loss:
            best_valid_loss = avg_loss
            save_checkpoint(checkpoint_directory, epoch, model, optimizer, avg_loss)
            save_report(pred_list=pred_list, target_list=target_list, report_path= os.path.join(report_path, str(epoch) + '.txt'))
        history['loss'].append(avg_loss)
        history['acc'].append(avg_acc)
        print("Accuracy: "+str(avg_acc))
        if early_stopper.early_stop(avg_loss):
                print("Early stop at epoch: "+str(epoch) + " with valid loss: "+str(avg_loss))
                break
    return history

def predict(model,device,dataloader,image_processor):
  list = []
  model.eval()

  data_iterator = tqdm(dataloader, desc=f'Evaluate')
  with torch.no_grad():
    for data in data_iterator:
      img = torch.squeeze(data[0]['pixel_values'],1)
      text = torch.squeeze(data[1]['input_ids'],1)
      target = torch.squeeze(data[2])
      # print(text.shape)
      img = img.to(device)
      text = text.to(device)
      target = target.to(device).to(torch.int64)
      
      logits = model(img,text)

      print("logits"+str(torch.sigmoid(logits)))
      pred = torch.round(torch.sigmoid(logits)).detach().cpu().numpy()
      list.extend(pred)
  return list

In [11]:
dataset_train = Dataset_MMM(train_file_path, train_img_dir,image_processor,tokenizer)
dataloader_train = DataLoader(dataset_train,batch_size=20,drop_last=True, shuffle=True)

/home/server-ailab-12gb/DUC-MMM/data/Musti/train/img
Len: 2374


In [12]:
img_encoder =[vit_encoder,resnet_encoder,swintransformer]
model = CrossModel(img_encoder,encoder_layer,text_embedding, pooling_layer)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

if len(train_index_checkpoint) > 0:
    PATH = os.path.join(checkpoint_directory, train_index_checkpoint + CHECKPOINT_EXTENSION)
    print("Load checkpoint: "+str(PATH))
    epoch, model, optimizer, loss = load_checkpoint(PATH, model, optimizer)

history_losses = train(dataloader_train,model,device ,epochs=100, optimizer=optimizer, cur_epoch=epoch, best_valid_loss=loss)
write_json(history_losses, parient_dir=history_save_folder, name=version)

Epoch 1: 100%|██████████| 118/118 [01:51<00:00,  1.06it/s, loss=0.144]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/1.pt
Accuracy: 0.7326271287718061


Epoch 2: 100%|██████████| 118/118 [01:52<00:00,  1.05it/s, loss=0.13] 


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/2.pt
Accuracy: 0.7601695035473776


Epoch 3: 100%|██████████| 118/118 [01:52<00:00,  1.04it/s, loss=0.122]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/3.pt
Accuracy: 0.7872881454936529


Epoch 4: 100%|██████████| 118/118 [01:47<00:00,  1.10it/s, loss=0.114]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/4.pt
Accuracy: 0.7927966228986191


Epoch 5: 100%|██████████| 118/118 [01:45<00:00,  1.12it/s, loss=0.108]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/5.pt
Accuracy: 0.8093220495571525


Epoch 6: 100%|██████████| 118/118 [01:51<00:00,  1.06it/s, loss=0.108]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/6.pt
Accuracy: 0.7961864501742993


Epoch 7: 100%|██████████| 118/118 [01:51<00:00,  1.05it/s, loss=0.107]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/7.pt
Accuracy: 0.8088983186220718


Epoch 8: 100%|██████████| 118/118 [01:49<00:00,  1.08it/s, loss=0.102]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/8.pt
Accuracy: 0.8093220490520283


Epoch 9: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.096] 


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/9.pt
Accuracy: 0.8216101815134792


Epoch 10: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0953]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/10.pt
Accuracy: 0.8254237432601088


Epoch 11: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.091] 


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/11.pt
Accuracy: 0.833474590111587


Epoch 12: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0894]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/12.pt
Accuracy: 0.8309322186445786


Epoch 13: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0875]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/13.pt
Accuracy: 0.8368644209231361


Epoch 14: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0868]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/14.pt
Accuracy: 0.8368644214282601


Epoch 15: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0845]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/15.pt
Accuracy: 0.8385593411275896


Epoch 16: 100%|██████████| 118/118 [01:45<00:00,  1.12it/s, loss=0.0837]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/16.pt
Accuracy: 0.8389830665063049


Epoch 17: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0822]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/17.pt
Accuracy: 0.8470339156308416


Epoch 18: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0811]


path: /home/server-ailab-12gb/DUC-MMM/checkpoint/T_W_nondropout_FOCAL_2_03_train_VIT_BERT/18.pt
Accuracy: 0.8500000162649963


Epoch 19: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0826]


Accuracy: 0.8394067939055168


Epoch 20: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0819]


Accuracy: 0.8411017105741015


Epoch 21: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0842]


Accuracy: 0.8377118797625526


Epoch 22: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0882]


Accuracy: 0.8300847633410309


Epoch 23: 100%|██████████| 118/118 [01:44<00:00,  1.13it/s, loss=0.0864]


Accuracy: 0.8372881508479684
Early stop at epoch: 23 with valid loss: 0.08644674266925302


FileNotFoundError: [Errno 2] No such file or directory: '/home/server-ailab-12gb/DUC-MMM/history_save/T_W_nondropout_FOCAL_2_03_train_VIT_BERT.json'

In [13]:
dataset_test = Dataset_MMM(test_file_path, test_img_dir,image_processor,tokenizer)
dataloader_test = DataLoader(dataset_test,batch_size=16,drop_last=False, shuffle=False)

/home/server-ailab-12gb/DUC-MMM/data/Musti/test/img
Len: 814


In [14]:
if len(test_index_checkpoint) > 0:
    img_encoder =[vit_encoder,resnet_encoder,swintransformer]
    model = CrossModel(img_encoder,encoder_layer,text_embedding, pooling_layer)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    PATH = os.path.join(checkpoint_directory, test_index_checkpoint + CHECKPOINT_EXTENSION)
    epoch, model, optimizer, loss = load_checkpoint(PATH, model, optimizer)
    print("Load checkpoint: "+str(PATH))
list = predict(model,device,dataloader_test,image_processor)
json_dict = read_json(test_file_path)
for i in range(len(json_dict['pairs'])):
    if list[i] == 0:
        json_dict['pairs'][i]['subtask1_label'] = 'NO'
    else:
        json_dict['pairs'][i]['subtask1_label'] = 'YES'
with open(predict_path, 'w') as f:
    json.dump(json_dict, f, indent=4)

Evaluate:   2%|▏         | 1/51 [00:00<00:17,  2.78it/s]

logitstensor([0.1730, 0.1602, 0.4063, 0.4766, 0.1863, 0.2823, 0.1747, 0.1730, 0.1722,
        0.1722, 0.5204, 0.4470, 0.1578, 0.3973, 0.1710, 0.5559],
       device='cuda:0')


Evaluate:   4%|▍         | 2/51 [00:00<00:16,  2.95it/s]

logitstensor([0.6488, 0.1602, 0.1649, 0.4880, 0.3651, 0.1724, 0.5466, 0.1609, 0.1806,
        0.1610, 0.1610, 0.6052, 0.5515, 0.1722, 0.5559, 0.1793],
       device='cuda:0')


Evaluate:   6%|▌         | 3/51 [00:01<00:15,  3.00it/s]

logitstensor([0.1722, 0.4512, 0.3224, 0.5742, 0.5186, 0.1751, 0.3857, 0.6581, 0.6629,
        0.5058, 0.1662, 0.4366, 0.1781, 0.6630, 0.1657, 0.6630],
       device='cuda:0')


Evaluate:   8%|▊         | 4/51 [00:01<00:15,  3.00it/s]

logitstensor([0.5579, 0.2617, 0.5146, 0.1641, 0.5579, 0.2022, 0.1662, 0.2617, 0.4176,
        0.1720, 0.4195, 0.1781, 0.6421, 0.4754, 0.2075, 0.1610],
       device='cuda:0')


Evaluate:  10%|▉         | 5/51 [00:01<00:15,  3.06it/s]

logitstensor([0.1602, 0.4960, 0.2075, 0.6052, 0.2617, 0.1781, 0.1610, 0.1917, 0.1795,
        0.5086, 0.2047, 0.5226, 0.1652, 0.4470, 0.3763, 0.1704],
       device='cuda:0')


Evaluate:  12%|█▏        | 6/51 [00:01<00:14,  3.05it/s]

logitstensor([0.4487, 0.6284, 0.6230, 0.4487, 0.5204, 0.1616, 0.1893, 0.1720, 0.5050,
        0.1662, 0.3934, 0.5466, 0.4886, 0.1863, 0.6254, 0.3847],
       device='cuda:0')


Evaluate:  14%|█▎        | 7/51 [00:02<00:14,  3.05it/s]

logitstensor([0.1573, 0.5559, 0.1645, 0.2617, 0.1924, 0.4826, 0.1748, 0.1893, 0.3628,
        0.4073, 0.2715, 0.1711, 0.1728, 0.1722, 0.1661, 0.6311],
       device='cuda:0')


Evaluate:  16%|█▌        | 8/51 [00:02<00:13,  3.09it/s]

logitstensor([0.1796, 0.1641, 0.1706, 0.5255, 0.1617, 0.1641, 0.2047, 0.1749, 0.5606,
        0.4964, 0.1579, 0.4855, 0.5374, 0.1748, 0.1662, 0.1863],
       device='cuda:0')


Evaluate:  18%|█▊        | 9/51 [00:02<00:13,  3.09it/s]

logitstensor([0.5721, 0.1673, 0.4524, 0.1645, 0.3651, 0.5255, 0.1610, 0.5945, 0.5146,
        0.1720, 0.3934, 0.6311, 0.5805, 0.3417, 0.5338, 0.1683],
       device='cuda:0')


Evaluate:  20%|█▉        | 10/51 [00:03<00:13,  2.95it/s]

logitstensor([0.2709, 0.1744, 0.5891, 0.1610, 0.6521, 0.1839, 0.1964, 0.1654, 0.4964,
        0.1666, 0.1673, 0.1704, 0.1795, 0.1697, 0.3628, 0.5805],
       device='cuda:0')


Evaluate:  22%|██▏       | 11/51 [00:03<00:13,  3.02it/s]

logitstensor([0.4290, 0.6513, 0.1796, 0.5805, 0.1616, 0.5883, 0.1573, 0.6284, 0.1612,
        0.1795, 0.1874, 0.4093, 0.4628, 0.1628, 0.4377, 0.5898],
       device='cuda:0')


Evaluate:  24%|██▎       | 12/51 [00:03<00:13,  2.99it/s]

logitstensor([0.4195, 0.6650, 0.5058, 0.6284, 0.6521, 0.6412, 0.2772, 0.1612, 0.4963,
        0.6551, 0.4706, 0.4589, 0.1646, 0.4524, 0.3847, 0.1662],
       device='cuda:0')


Evaluate:  25%|██▌       | 13/51 [00:04<00:13,  2.90it/s]

logitstensor([0.6551, 0.1727, 0.2022, 0.1611, 0.1646, 0.5898, 0.4176, 0.5515, 0.5147,
        0.2014, 0.1696, 0.1833, 0.1771, 0.1858, 0.4392, 0.1851],
       device='cuda:0')


Evaluate:  27%|██▋       | 14/51 [00:04<00:12,  2.93it/s]

logitstensor([0.2571, 0.1671, 0.2183, 0.1658, 0.3864, 0.1702, 0.1626, 0.1640, 0.5076,
        0.2208, 0.1709, 0.2208, 0.2823, 0.1728, 0.2571, 0.1609],
       device='cuda:0')


Evaluate:  29%|██▉       | 15/51 [00:05<00:12,  2.98it/s]

logitstensor([0.2067, 0.1803, 0.2772, 0.1696, 0.1855, 0.4880, 0.1579, 0.1663, 0.6322,
        0.4880, 0.4656, 0.2506, 0.5147, 0.5395, 0.3763, 0.4392],
       device='cuda:0')


Evaluate:  31%|███▏      | 16/51 [00:05<00:11,  2.93it/s]

logitstensor([0.1627, 0.1923, 0.4230, 0.1631, 0.6561, 0.3973, 0.4262, 0.6402, 0.4022,
        0.1611, 0.2047, 0.1662, 0.5481, 0.3763, 0.2802, 0.1572],
       device='cuda:0')


Evaluate:  33%|███▎      | 17/51 [00:05<00:11,  2.99it/s]

logitstensor([0.4290, 0.1645, 0.2014, 0.4392, 0.4524, 0.1921, 0.4073, 0.1965, 0.6380,
        0.3628, 0.3357, 0.2571, 0.1663, 0.5387, 0.5147, 0.6610],
       device='cuda:0')


Evaluate:  35%|███▌      | 18/51 [00:06<00:10,  3.02it/s]

logitstensor([0.1732, 0.4125, 0.1724, 0.3894, 0.6322, 0.1924, 0.4072, 0.2174, 0.5174,
        0.1709, 0.5086, 0.5147, 0.4290, 0.1930, 0.1648, 0.5147],
       device='cuda:0')


Evaluate:  37%|███▋      | 19/51 [00:06<00:10,  2.98it/s]

logitstensor([0.1732, 0.5884, 0.2007, 0.1579, 0.1825, 0.6506, 0.1611, 0.5147, 0.4779,
        0.1924, 0.1722, 0.6322, 0.2919, 0.1921, 0.1657, 0.4470],
       device='cuda:0')


Evaluate:  39%|███▉      | 20/51 [00:06<00:10,  3.02it/s]

logitstensor([0.2145, 0.4290, 0.4397, 0.6506, 0.1669, 0.1683, 0.2081, 0.4371, 0.6380,
        0.1701, 0.3732, 0.2183, 0.5374, 0.4963, 0.6523, 0.2802],
       device='cuda:0')


Evaluate:  41%|████      | 21/51 [00:06<00:09,  3.06it/s]

logitstensor([0.2208, 0.4886, 0.5374, 0.3680, 0.6506, 0.6506, 0.2183, 0.1858, 0.1578,
        0.4176, 0.4656, 0.4429, 0.4072, 0.6583, 0.1924, 0.1790],
       device='cuda:0')


Evaluate:  43%|████▎     | 22/51 [00:07<00:09,  2.98it/s]

logitstensor([0.1763, 0.1579, 0.4366, 0.5076, 0.6523, 0.4397, 0.4063, 0.4886, 0.5884,
        0.1609, 0.1724, 0.1611, 0.1696, 0.1803, 0.5466, 0.4963],
       device='cuda:0')


Evaluate:  45%|████▌     | 23/51 [00:07<00:09,  3.03it/s]

logitstensor([0.4397, 0.1649, 0.2047, 0.4392, 0.1649, 0.2642, 0.2823, 0.1631, 0.2265,
        0.1627, 0.6642, 0.6550, 0.5076, 0.1649, 0.2802, 0.1777],
       device='cuda:0')


Evaluate:  47%|████▋     | 24/51 [00:07<00:08,  3.06it/s]

logitstensor([0.1753, 0.4024, 0.5884, 0.3845, 0.5590, 0.2491, 0.4208, 0.1680, 0.1675,
        0.5147, 0.2823, 0.3680, 0.2265, 0.1671, 0.1666, 0.1680],
       device='cuda:0')


Evaluate:  49%|████▉     | 25/51 [00:08<00:08,  2.95it/s]

logitstensor([0.1645, 0.1572, 0.3934, 0.1709, 0.2183, 0.1645, 0.1658, 0.2067, 0.1662,
        0.1645, 0.4066, 0.1635, 0.5466, 0.6523, 0.1572, 0.2047],
       device='cuda:0')


Evaluate:  51%|█████     | 26/51 [00:08<00:08,  2.96it/s]

logitstensor([0.2506, 0.5590, 0.3899, 0.4995, 0.3799, 0.3899, 0.1902, 0.4995, 0.5590,
        0.1902, 0.1704, 0.4468, 0.4995, 0.4938, 0.4938, 0.1704],
       device='cuda:0')


Evaluate:  53%|█████▎    | 27/51 [00:09<00:07,  3.01it/s]

logitstensor([0.5122, 0.4425, 0.4869, 0.5901, 0.2058, 0.2058, 0.5215, 0.1704, 0.4984,
        0.3338, 0.1704, 0.1742, 0.1868, 0.1648, 0.4984, 0.3808],
       device='cuda:0')


Evaluate:  55%|█████▍    | 28/51 [00:09<00:07,  3.04it/s]

logitstensor([0.2313, 0.4865, 0.3817, 0.1802, 0.3216, 0.4147, 0.3447, 0.3899, 0.2362,
        0.3524, 0.3338, 0.4066, 0.4392, 0.4984, 0.4938, 0.2372],
       device='cuda:0')


Evaluate:  57%|█████▋    | 29/51 [00:09<00:07,  3.07it/s]

logitstensor([0.4429, 0.4855, 0.1704, 0.4804, 0.2577, 0.4379, 0.4754, 0.4995, 0.4982,
        0.1603, 0.1758, 0.4546, 0.1802, 0.2058, 0.2187, 0.4546],
       device='cuda:0')


Evaluate:  59%|█████▉    | 30/51 [00:09<00:06,  3.08it/s]

logitstensor([0.2942, 0.3522, 0.4855, 0.1758, 0.4031, 0.1704, 0.4865, 0.1638, 0.1797,
        0.3625, 0.1704, 0.3805, 0.1902, 0.1603, 0.1902, 0.6426],
       device='cuda:0')


Evaluate:  61%|██████    | 31/51 [00:10<00:06,  2.98it/s]

logitstensor([0.3817, 0.1902, 0.3860, 0.4031, 0.3978, 0.5788, 0.3625, 0.1772, 0.2130,
        0.4468, 0.5590, 0.4379, 0.1797, 0.4392, 0.5115, 0.4865],
       device='cuda:0')


Evaluate:  63%|██████▎   | 32/51 [00:10<00:06,  3.00it/s]

logitstensor([0.3978, 0.4147, 0.1746, 0.4147, 0.1802, 0.2000, 0.5115, 0.5599, 0.4022,
        0.3338, 0.1676, 0.4982, 0.2058, 0.6551, 0.1902, 0.4513],
       device='cuda:0')


Evaluate:  65%|██████▍   | 33/51 [00:10<00:05,  3.04it/s]

logitstensor([0.2058, 0.4487, 0.5249, 0.5371, 0.1704, 0.1771, 0.1753, 0.4147, 0.2058,
        0.4855, 0.1573, 0.1632, 0.3447, 0.1758, 0.1758, 0.4984],
       device='cuda:0')


Evaluate:  67%|██████▋   | 34/51 [00:11<00:05,  2.98it/s]

logitstensor([0.1638, 0.3917, 0.2022, 0.3338, 0.1704, 0.4661, 0.1632, 0.4546, 0.5312,
        0.4379, 0.5115, 0.2194, 0.2680, 0.2783, 0.3805, 0.4147],
       device='cuda:0')


Evaluate:  69%|██████▊   | 35/51 [00:11<00:05,  3.03it/s]

logitstensor([0.1862, 0.6461, 0.1603, 0.4030, 0.5312, 0.4030, 0.4468, 0.4995, 0.2058,
        0.2577, 0.5788, 0.1638, 0.1797, 0.4995, 0.2060, 0.1746],
       device='cuda:0')


Evaluate:  71%|███████   | 36/51 [00:11<00:04,  3.07it/s]

logitstensor([0.4546, 0.2139, 0.5115, 0.3978, 0.4995, 0.4984, 0.1814, 0.4982, 0.2680,
        0.1645, 0.2598, 0.1573, 0.5255, 0.2945, 0.1603, 0.4984],
       device='cuda:0')


Evaluate:  73%|███████▎  | 37/51 [00:12<00:04,  3.08it/s]

logitstensor([0.3338, 0.4085, 0.3070, 0.4147, 0.1802, 0.2139, 0.4995, 0.1797, 0.2997,
        0.4147, 0.4379, 0.3805, 0.4984, 0.2257, 0.4951, 0.1638],
       device='cuda:0')


Evaluate:  75%|███████▍  | 38/51 [00:12<00:04,  3.12it/s]

logitstensor([0.1852, 0.1603, 0.1638, 0.3338, 0.1758, 0.2680, 0.1666, 0.5788, 0.1902,
        0.2313, 0.1638, 0.2491, 0.6183, 0.3166, 0.1764, 0.3805],
       device='cuda:0')


Evaluate:  76%|███████▋  | 39/51 [00:12<00:03,  3.15it/s]

logitstensor([0.4379, 0.4995, 0.1603, 0.4661, 0.3808, 0.2598, 0.6614, 0.5631, 0.3619,
        0.5788, 0.4875, 0.3114, 0.4779, 0.1985, 0.6426, 0.1679],
       device='cuda:0')


Evaluate:  78%|███████▊  | 40/51 [00:13<00:03,  3.16it/s]

logitstensor([0.1998, 0.4463, 0.4628, 0.5864, 0.2139, 0.4430, 0.6056, 0.2390, 0.2130,
        0.6056, 0.3619, 0.4290, 0.2269, 0.6610, 0.5940, 0.4943],
       device='cuda:0')


Evaluate:  80%|████████  | 41/51 [00:13<00:03,  3.16it/s]

logitstensor([0.1758, 0.1985, 0.3857, 0.4463, 0.5794, 0.5641, 0.4284, 0.2242, 0.6387,
        0.5249, 0.4620, 0.3433, 0.2661, 0.4372, 0.3668, 0.4941],
       device='cuda:0')


Evaluate:  82%|████████▏ | 42/51 [00:13<00:02,  3.18it/s]

logitstensor([0.3845, 0.1655, 0.1965, 0.6056, 0.3334, 0.6003, 0.2028, 0.4262, 0.5641,
        0.1723, 0.6354, 0.2081, 0.1641, 0.5864, 0.2020, 0.1612],
       device='cuda:0')


Evaluate:  84%|████████▍ | 43/51 [00:14<00:02,  3.18it/s]

logitstensor([0.4392, 0.2097, 0.2097, 0.1655, 0.2019, 0.6324, 0.6341, 0.2548, 0.6598,
        0.1665, 0.4502, 0.5197, 0.5639, 0.2873, 0.3922, 0.3922],
       device='cuda:0')


Evaluate:  86%|████████▋ | 44/51 [00:14<00:02,  3.06it/s]

logitstensor([0.2291, 0.4658, 0.3354, 0.3334, 0.5077, 0.1965, 0.4484, 0.4963, 0.1610,
        0.4524, 0.4208, 0.2055, 0.1671, 0.5480, 0.3679, 0.3318],
       device='cuda:0')


Evaluate:  88%|████████▊ | 45/51 [00:14<00:01,  3.08it/s]

logitstensor([0.6056, 0.3992, 0.3134, 0.6522, 0.3715, 0.1702, 0.6610, 0.5480, 0.5249,
        0.6354, 0.2534, 0.4943, 0.6387, 0.2838, 0.1744, 0.4943],
       device='cuda:0')


Evaluate:  90%|█████████ | 46/51 [00:15<00:01,  3.13it/s]

logitstensor([0.4858, 0.1637, 0.6599, 0.2078, 0.5061, 0.1998, 0.2144, 0.3313, 0.3679,
        0.2784, 0.3775, 0.3737, 0.3111, 0.3212, 0.4378, 0.1641],
       device='cuda:0')


Evaluate:  92%|█████████▏| 47/51 [00:15<00:01,  3.16it/s]

logitstensor([0.4487, 0.6598, 0.1681, 0.5963, 0.4875, 0.5579, 0.4093, 0.3845, 0.4943,
        0.3495, 0.1612, 0.1673, 0.4910, 0.5579, 0.2130, 0.3807],
       device='cuda:0')


Evaluate:  94%|█████████▍| 48/51 [00:15<00:00,  3.18it/s]

logitstensor([0.6475, 0.4779, 0.1579, 0.5864, 0.2642, 0.3884, 0.1581, 0.1729, 0.5076,
        0.5455, 0.6598, 0.3468, 0.1627, 0.4875, 0.5158, 0.2930],
       device='cuda:0')


Evaluate:  96%|█████████▌| 49/51 [00:16<00:00,  3.19it/s]

logitstensor([0.2358, 0.3433, 0.6052, 0.2102, 0.2014, 0.5963, 0.5963, 0.1723, 0.4463,
        0.5725, 0.4658, 0.4502, 0.1790, 0.1581, 0.6610, 0.4661],
       device='cuda:0')


Evaluate:  98%|█████████▊| 50/51 [00:16<00:00,  3.17it/s]

logitstensor([0.5639, 0.1771, 0.1645, 0.4092, 0.5478, 0.5223, 0.3212, 0.2058, 0.5226,
        0.2097, 0.4607, 0.4941, 0.4148, 0.5115, 0.4144, 0.2357],
       device='cuda:0')


Evaluate: 100%|██████████| 51/51 [00:16<00:00,  3.06it/s]

logitstensor([0.1724, 0.1718, 0.1793, 0.4311, 0.5655, 0.3298, 0.3922, 0.1662, 0.1617,
        0.6331, 0.2885, 0.4502, 0.3055, 0.2608], device='cuda:0')





In [None]:
# img_encoder =[vit_encoder,resnet_encoder,swintransformer]
# model = CrossModel(img_encoder,encoder_layer,text_embedding, pooling_layer)

# lr = 0.001
# model = model.to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# epoch = 0
# loss = 10000
# # index_checkpoint = '3'
# # PATH = os.path.join(checkpoint_directory, index_checkpoint + CHECKPOINT_EXTENSION)
# # epoch, model, optimizer, loss = load_checkpoint(PATH, model, optimizer)
# # print("Load checkpoint: "+str(PATH))

# history_losses = train(dataloader_train,model,device ,epochs=100, optimizer=optimizer, cur_epoch=epoch, best_valid_loss=loss)

## Predict

In [None]:
# def predict(model,device,dataloader,image_processor):
#   list = []
#   model.eval()

#   data_iterator = tqdm(dataloader, desc=f'Epoch')
#   with torch.no_grad():
#     for data in data_iterator:
#       img = torch.squeeze(data[0]['pixel_values'],1)
#       text = torch.squeeze(data[1]['input_ids'],1)
#       target = torch.squeeze(data[2])
#       # print(text.shape)
#       img = img.to(device)
#       text = text.to(device)
#       target = target.to(device).to(torch.int64)
#       optimizer.zero_grad()
      
#       logits = model(img,text)

#       print("logits"+str(torch.sigmoid(logits)))
#       pred = torch.round(torch.sigmoid(logits)).detach().cpu().numpy()
#       list.extend(pred)
#   return list

In [None]:
# dataset_test = Dataset_MMM(test_file_path, test_img_dir,image_processor,tokenizer,is_url=True)
# dataloader_test = DataLoader(dataset_test,batch_size=16,drop_last=False, shuffle=False)

In [None]:
# img_encoder =[vit_encoder,resnet_encoder,swintransformer]
# model = CrossModel(img_encoder,encoder_layer,text_embedding, pooling_layer)

# lr=0.001
# model = model.to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# index_checkpoint = '19'
# PATH = os.path.join(checkpoint_directory, index_checkpoint + CHECKPOINT_EXTENSION)
# epoch, model, optimizer, loss = load_checkpoint(PATH, model, optimizer)
# print("Load checkpoint: "+str(PATH))

# list = predict(model,device,dataloader_test,image_processor)
# json_dict = read_json(test_file_path)
# for i in range(len(json_dict)):
#     if list[i] == 0:
#         json_dict[i]['subtask1_label'] = 'NO'
#     else:
#         json_dict[i]['subtask1_label'] = 'YES'
# with open(test_save_path, 'w') as f:
#     json.dump(json_dict, f, indent=4)

In [None]:
len(json_dict)

1