In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from PIL import Image
from torchmetrics.classification import BinaryAUROC
import pandas as pd
import os
import clip
import torch.nn as nn
import torch.nn.functional as F

import IPython.display
import matplotlib.pyplot as plt
from PIL import Image

from collections import OrderedDict

print("Torch version:", torch.__version__)


Torch version: 1.13.0+cu117


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
HOME = os.environ.get("HOME")
images_path = f"{HOME}/.cache/torch/mmf/data/datasets/hateful_memes/defaults/images/"
annotations_path = f"{HOME}/.cache/torch/mmf/data/datasets/hateful_memes/defaults/annotations/"

In [4]:
class HMDataset(Dataset):
    def __init__(self, images_path: str, annotation_path: str, image_transform=None, text_transform=None) -> None:
        self.images_path = images_path
        self.annotation_path = annotation_path
        self.image_transform = image_transform
        self.text_transform = text_transform
        assert self.annotation_path.endswith(".jsonl"), f"Invalid annotation file format. Format should be '.jsonl', not {self.annotation_path.split('.')[0]}"
        self.annotation: pd.DataFrame = pd.read_json(self.annotation_path, lines=True)      

    def __len__(self):
        return self.annotation.shape[0]

    def __getitem__(self, index):
        img_path = os.path.join(self.images_path.split('img')[0], self.annotation.loc[index,"img"])
        image = Image.open(img_path).convert("RGB")
        text = self.annotation.loc[index,"text"]
        label = self.annotation.loc[index,"label"]
        if self.image_transform:
            image = self.image_transform(image)
        if self.text_transform:
            text = self.text_transform(text)
        return image, text, torch.tensor(label)

In [5]:
clip.available_models()


['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [6]:
model, preprocess = clip.load("ViT-L/14@336px")
model.to(device).eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

100%|███████████████████████████████████████| 891M/891M [02:54<00:00, 5.35MiB/s]


Model parameters: 427,944,193
Input resolution: 336
Context length: 77
Vocab size: 49408


In [7]:
class HMMLP(nn.Module):

    def __init__(self, n_in=768*2, n_out=1, ) -> None:
        super().__init__()

        self.fc1 = nn.Linear(n_in, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 16)
        self.fc4 = nn.Linear(16, n_out)

    def forward(self, x):
        
        # x = self.fc1(x)
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = F.gelu(self.fc3(x))
        x = self.fc4(x)

        return x

In [8]:
batch_size = 32

def text_preprocess(text):
    return clip.tokenize(text, truncate=True)

hm_train_dataset = HMDataset(images_path, f"{annotations_path}/train_v2.jsonl", image_transform=preprocess, text_transform=text_preprocess)
hm_test_dataset = HMDataset(images_path, f"{annotations_path}/test_unseen.jsonl", image_transform=preprocess, text_transform=text_preprocess)
hm_val_dataset = HMDataset(images_path, f"{annotations_path}/dev_unseen.jsonl", image_transform=preprocess, text_transform=text_preprocess)

train_dataloader = DataLoader(hm_train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(hm_test_dataset, batch_size=batch_size, shuffle=False)
val_dataloader = DataLoader(hm_val_dataset, batch_size=batch_size, shuffle=False)


In [9]:
from tqdm import tqdm
net = HMMLP()

model_path = 'base_model.pt'
net = net.to(device)
criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.AdamW(net.parameters(), lr=0.01)

epochs = 20
print_every = 50

for epoch in range(epochs):
    
    running_loss = 0
    for i, data in enumerate(tqdm(train_dataloader), 0):
        images, texts, labels = data
        images = images.to(device)
        texts = texts.to(device)
        labels = labels.float().squeeze().to(device)

        with torch.no_grad():
            images = model.encode_image(images) # input_dim: batch_size x 3 x H x W; output_dim: batch_size x 512
            texts = model.encode_text(texts.squeeze()) # input_dim: batch_size x 77; output_dim: batch_size x 512
        
        fused_images_texts = torch.hstack((images,texts))
        fused_images_texts.requires_grad_()
        fused_images_texts = fused_images_texts.float()

        optimizer.zero_grad()

        # Forward pass on the fused data

        output = net(fused_images_texts)

        loss = criterion(output.squeeze(), labels)

        # Compute gradient
        loss.backward()
        # Update weight
        optimizer.step()

        running_loss += loss.item()

        if i%print_every==(print_every - 1): 
            
            print(f"[Epoch {epoch + 1}, step {i+1:3d}] loss: {running_loss/print_every:.5f}")
            running_loss = 0.0
    ## Switch to eval mode
    torch.save(net.state_dict(), model_path)
    net.eval()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0

    preds_all_val = torch.tensor([]).cuda()
    labels_all_val = torch.tensor([]).cuda()
    for i, data in enumerate(val_dataloader, 0):
        
        images, texts, labels = data
        images = images.to(device)
        texts = texts.to(device)
        labels = labels.float().squeeze().to(device)

        with torch.no_grad():
            images = model.encode_image(images) # input_dim: batch_size x 3 x H x W; output_dim: batch_size x 512
            texts = model.encode_text(texts.squeeze()) # input_dim: batch_size x 77; output_dim: batch_size x 512
        
        fused_images_texts = torch.hstack((images,texts))
        fused_images_texts.requires_grad_()
        fused_images_texts = fused_images_texts.float()

        with torch.no_grad():
            output = net(fused_images_texts)

        loss = criterion(output.squeeze(), labels)

        running_loss += loss.item()

        correct_preds += sum(torch.sigmoid(output).squeeze().round()==labels)
        total_preds += len(labels)

        preds_all_val = torch.cat((preds_all_val, torch.sigmoid(output).squeeze()))
        labels_all_val = torch.cat((labels_all_val, labels))

    auroc = BinaryAUROC()
    auroc_score = auroc(preds_all_val, labels_all_val.int())
    print(f"\n[Epoch {epoch +1}, step {i+1:3d}] val loss: {running_loss/i+1:.5f} "
    f"accuracy: {torch.mean(preds_all_val.round()==labels_all_val, dtype=torch.float32)} auroc: {auroc_score}\n")
    # f"accuracy: {correct_preds/total_preds} auroc: {auroc_score}\n")

print("Finished Training!")

 18%|███████▌                                  | 50/279 [01:07<04:48,  1.26s/it]

[Epoch 1, step  50] loss: 0.56839


 36%|██████████████▋                          | 100/279 [02:10<03:49,  1.28s/it]

[Epoch 1, step 100] loss: 0.52300


 54%|██████████████████████                   | 150/279 [03:13<02:43,  1.26s/it]

[Epoch 1, step 150] loss: 0.49301


 72%|█████████████████████████████▍           | 200/279 [04:14<01:37,  1.24s/it]

[Epoch 1, step 200] loss: 0.47013


 90%|████████████████████████████████████▋    | 250/279 [05:18<00:36,  1.24s/it]

[Epoch 1, step 250] loss: 0.46732


100%|█████████████████████████████████████████| 279/279 [05:55<00:00,  1.27s/it]



[Epoch 1, step  17] val loss: 1.66176 accuracy: 0.664814829826355 auroc: 0.69667649269104



 18%|███████▌                                  | 50/279 [01:01<04:49,  1.26s/it]

[Epoch 2, step  50] loss: 0.38719


 36%|██████████████▋                          | 100/279 [02:03<03:45,  1.26s/it]

[Epoch 2, step 100] loss: 0.41231


 54%|██████████████████████                   | 150/279 [03:06<02:45,  1.28s/it]

[Epoch 2, step 150] loss: 0.38168


 72%|█████████████████████████████▍           | 200/279 [04:07<01:40,  1.27s/it]

[Epoch 2, step 200] loss: 0.40344


 90%|████████████████████████████████████▋    | 250/279 [05:10<00:38,  1.32s/it]

[Epoch 2, step 250] loss: 0.42593


100%|█████████████████████████████████████████| 279/279 [05:46<00:00,  1.24s/it]



[Epoch 2, step  17] val loss: 1.69077 accuracy: 0.7018518447875977 auroc: 0.7366912364959717



 18%|███████▌                                  | 50/279 [01:01<04:49,  1.26s/it]

[Epoch 3, step  50] loss: 0.32935


 36%|██████████████▋                          | 100/279 [02:03<03:10,  1.06s/it]

[Epoch 3, step 100] loss: 0.32473


 54%|██████████████████████                   | 150/279 [03:04<02:39,  1.24s/it]

[Epoch 3, step 150] loss: 0.35195


 72%|█████████████████████████████▍           | 200/279 [04:07<01:38,  1.25s/it]

[Epoch 3, step 200] loss: 0.35761


 90%|████████████████████████████████████▋    | 250/279 [05:11<00:37,  1.30s/it]

[Epoch 3, step 250] loss: 0.33693


100%|█████████████████████████████████████████| 279/279 [05:47<00:00,  1.24s/it]



[Epoch 3, step  17] val loss: 1.67247 accuracy: 0.690740704536438 auroc: 0.7481029033660889



 18%|███████▌                                  | 50/279 [01:02<04:49,  1.27s/it]

[Epoch 4, step  50] loss: 0.24119


 36%|██████████████▋                          | 100/279 [02:06<03:41,  1.24s/it]

[Epoch 4, step 100] loss: 0.26928


 54%|██████████████████████                   | 150/279 [03:08<02:42,  1.26s/it]

[Epoch 4, step 150] loss: 0.29046


 72%|█████████████████████████████▍           | 200/279 [04:11<01:43,  1.30s/it]

[Epoch 4, step 200] loss: 0.30532


 90%|████████████████████████████████████▋    | 250/279 [05:15<00:36,  1.25s/it]

[Epoch 4, step 250] loss: 0.29911


100%|█████████████████████████████████████████| 279/279 [05:49<00:00,  1.25s/it]



[Epoch 4, step  17] val loss: 1.66855 accuracy: 0.7111110687255859 auroc: 0.7541029453277588



 18%|███████▌                                  | 50/279 [00:52<03:57,  1.04s/it]

[Epoch 5, step  50] loss: 0.20172


 36%|██████████████▋                          | 100/279 [01:44<03:08,  1.06s/it]

[Epoch 5, step 100] loss: 0.20109


 54%|██████████████████████                   | 150/279 [02:36<02:15,  1.05s/it]

[Epoch 5, step 150] loss: 0.21513


 72%|█████████████████████████████▍           | 200/279 [03:28<01:21,  1.04s/it]

[Epoch 5, step 200] loss: 0.24404


 90%|████████████████████████████████████▋    | 250/279 [04:20<00:28,  1.00it/s]

[Epoch 5, step 250] loss: 0.21925


100%|█████████████████████████████████████████| 279/279 [04:51<00:00,  1.04s/it]



[Epoch 5, step  17] val loss: 1.88724 accuracy: 0.6981481313705444 auroc: 0.74664705991745



 18%|███████▌                                  | 50/279 [00:52<04:07,  1.08s/it]

[Epoch 6, step  50] loss: 0.15640


 36%|██████████████▋                          | 100/279 [01:44<03:07,  1.05s/it]

[Epoch 6, step 100] loss: 0.16892


 54%|██████████████████████                   | 150/279 [02:36<02:07,  1.01it/s]

[Epoch 6, step 150] loss: 0.17869


 72%|█████████████████████████████▍           | 200/279 [03:29<01:20,  1.02s/it]

[Epoch 6, step 200] loss: 0.18588


 90%|████████████████████████████████████▋    | 250/279 [04:20<00:28,  1.01it/s]

[Epoch 6, step 250] loss: 0.19298


100%|█████████████████████████████████████████| 279/279 [04:52<00:00,  1.05s/it]



[Epoch 6, step  17] val loss: 2.18095 accuracy: 0.699999988079071 auroc: 0.7497646808624268



 18%|███████▌                                  | 50/279 [01:01<04:42,  1.23s/it]

[Epoch 7, step  50] loss: 0.12919


 36%|██████████████▋                          | 100/279 [02:02<03:14,  1.09s/it]

[Epoch 7, step 100] loss: 0.13143


 54%|██████████████████████                   | 150/279 [03:00<02:39,  1.24s/it]

[Epoch 7, step 150] loss: 0.14094


 72%|█████████████████████████████▍           | 200/279 [04:01<01:32,  1.17s/it]

[Epoch 7, step 200] loss: 0.13993


 90%|████████████████████████████████████▋    | 250/279 [05:01<00:34,  1.18s/it]

[Epoch 7, step 250] loss: 0.13300


100%|█████████████████████████████████████████| 279/279 [05:36<00:00,  1.21s/it]



[Epoch 7, step  17] val loss: 2.04305 accuracy: 0.6814814805984497 auroc: 0.7513823509216309



 18%|███████▌                                  | 50/279 [01:01<04:39,  1.22s/it]

[Epoch 8, step  50] loss: 0.09396


 36%|██████████████▋                          | 100/279 [02:00<03:44,  1.25s/it]

[Epoch 8, step 100] loss: 0.10680


 54%|██████████████████████                   | 150/279 [03:01<02:21,  1.10s/it]

[Epoch 8, step 150] loss: 0.12594


 72%|█████████████████████████████▍           | 200/279 [04:03<01:38,  1.25s/it]

[Epoch 8, step 200] loss: 0.13318


 90%|████████████████████████████████████▋    | 250/279 [05:03<00:34,  1.20s/it]

[Epoch 8, step 250] loss: 0.14423


100%|█████████████████████████████████████████| 279/279 [05:39<00:00,  1.22s/it]



[Epoch 8, step  17] val loss: 2.61001 accuracy: 0.7074074149131775 auroc: 0.7447794079780579



 18%|███████▌                                  | 50/279 [00:58<04:35,  1.20s/it]

[Epoch 9, step  50] loss: 0.06488


 36%|██████████████▋                          | 100/279 [01:58<03:39,  1.23s/it]

[Epoch 9, step 100] loss: 0.09647


 54%|██████████████████████                   | 150/279 [02:59<02:35,  1.21s/it]

[Epoch 9, step 150] loss: 0.08751


 72%|█████████████████████████████▍           | 200/279 [04:00<01:36,  1.23s/it]

[Epoch 9, step 200] loss: 0.08256


 90%|████████████████████████████████████▋    | 250/279 [05:01<00:35,  1.23s/it]

[Epoch 9, step 250] loss: 0.07911


100%|█████████████████████████████████████████| 279/279 [05:37<00:00,  1.21s/it]



[Epoch 9, step  17] val loss: 2.45956 accuracy: 0.6851851940155029 auroc: 0.7418676614761353



 18%|███████▌                                  | 50/279 [01:01<04:41,  1.23s/it]

[Epoch 10, step  50] loss: 0.05890


 36%|██████████████▋                          | 100/279 [02:02<03:39,  1.22s/it]

[Epoch 10, step 100] loss: 0.08566


 54%|██████████████████████                   | 150/279 [03:02<02:32,  1.18s/it]

[Epoch 10, step 150] loss: 0.07540


 72%|█████████████████████████████▍           | 200/279 [04:02<01:26,  1.09s/it]

[Epoch 10, step 200] loss: 0.06857


 90%|████████████████████████████████████▋    | 250/279 [05:02<00:34,  1.20s/it]

[Epoch 10, step 250] loss: 0.08492


100%|█████████████████████████████████████████| 279/279 [05:38<00:00,  1.21s/it]



[Epoch 10, step  17] val loss: 2.20992 accuracy: 0.6870370507240295 auroc: 0.736544132232666



 18%|███████▌                                  | 50/279 [01:02<04:47,  1.26s/it]

[Epoch 11, step  50] loss: 0.08107


 36%|██████████████▋                          | 100/279 [02:03<03:35,  1.20s/it]

[Epoch 11, step 100] loss: 0.08152


 54%|██████████████████████                   | 150/279 [03:04<02:35,  1.21s/it]

[Epoch 11, step 150] loss: 0.07019


 72%|█████████████████████████████▍           | 200/279 [04:03<01:36,  1.23s/it]

[Epoch 11, step 200] loss: 0.09393


 90%|████████████████████████████████████▋    | 250/279 [05:04<00:35,  1.21s/it]

[Epoch 11, step 250] loss: 0.06127


100%|█████████████████████████████████████████| 279/279 [05:39<00:00,  1.22s/it]



[Epoch 11, step  17] val loss: 2.34910 accuracy: 0.6870370507240295 auroc: 0.7356323599815369



 18%|███████▌                                  | 50/279 [01:01<04:33,  1.20s/it]

[Epoch 12, step  50] loss: 0.06756


 36%|██████████████▋                          | 100/279 [02:01<03:37,  1.22s/it]

[Epoch 12, step 100] loss: 0.07118


 54%|██████████████████████                   | 150/279 [03:03<02:41,  1.25s/it]

[Epoch 12, step 150] loss: 0.08598


 72%|█████████████████████████████▍           | 200/279 [04:05<01:40,  1.28s/it]

[Epoch 12, step 200] loss: 0.07506


 90%|████████████████████████████████████▋    | 250/279 [05:04<00:29,  1.03s/it]

[Epoch 12, step 250] loss: 0.09445


100%|█████████████████████████████████████████| 279/279 [05:37<00:00,  1.21s/it]



[Epoch 12, step  17] val loss: 2.65943 accuracy: 0.7166666388511658 auroc: 0.7426618337631226



 18%|███████▌                                  | 50/279 [00:54<03:58,  1.04s/it]

[Epoch 13, step  50] loss: 0.05115


 36%|██████████████▋                          | 100/279 [01:52<03:26,  1.16s/it]

[Epoch 13, step 100] loss: 0.06665


 54%|██████████████████████                   | 150/279 [02:47<02:16,  1.06s/it]

[Epoch 13, step 150] loss: 0.08688


 72%|█████████████████████████████▍           | 200/279 [03:40<01:19,  1.00s/it]

[Epoch 13, step 200] loss: 0.08942


 90%|████████████████████████████████████▋    | 250/279 [04:35<00:32,  1.14s/it]

[Epoch 13, step 250] loss: 0.07720


100%|█████████████████████████████████████████| 279/279 [05:05<00:00,  1.10s/it]



[Epoch 13, step  17] val loss: 2.61954 accuracy: 0.7055555582046509 auroc: 0.7530882358551025



 18%|███████▌                                  | 50/279 [00:52<03:56,  1.03s/it]

[Epoch 14, step  50] loss: 0.04428


 36%|██████████████▋                          | 100/279 [01:46<03:12,  1.08s/it]

[Epoch 14, step 100] loss: 0.03295


 54%|██████████████████████                   | 150/279 [02:42<02:08,  1.01it/s]

[Epoch 14, step 150] loss: 0.05554


 72%|█████████████████████████████▍           | 200/279 [03:37<01:39,  1.25s/it]

[Epoch 14, step 200] loss: 0.05653


 90%|████████████████████████████████████▋    | 250/279 [04:36<00:35,  1.21s/it]

[Epoch 14, step 250] loss: 0.06004


100%|█████████████████████████████████████████| 279/279 [05:12<00:00,  1.12s/it]



[Epoch 14, step  17] val loss: 2.56722 accuracy: 0.7277777791023254 auroc: 0.7503896951675415



 18%|███████▌                                  | 50/279 [01:00<04:33,  1.20s/it]

[Epoch 15, step  50] loss: 0.02045


 36%|██████████████▋                          | 100/279 [02:01<03:36,  1.21s/it]

[Epoch 15, step 100] loss: 0.02982


 54%|██████████████████████                   | 150/279 [02:59<02:16,  1.06s/it]

[Epoch 15, step 150] loss: 0.07793


 72%|█████████████████████████████▍           | 200/279 [04:00<01:37,  1.23s/it]

[Epoch 15, step 200] loss: 0.07396


 90%|████████████████████████████████████▋    | 250/279 [05:01<00:35,  1.22s/it]

[Epoch 15, step 250] loss: 0.05920


100%|█████████████████████████████████████████| 279/279 [05:37<00:00,  1.21s/it]



[Epoch 15, step  17] val loss: 2.68049 accuracy: 0.7111110687255859 auroc: 0.7574338316917419



 18%|███████▌                                  | 50/279 [00:59<04:22,  1.14s/it]

[Epoch 16, step  50] loss: 0.07295


 36%|██████████████▋                          | 100/279 [01:52<03:06,  1.04s/it]

[Epoch 16, step 100] loss: 0.06701


 54%|██████████████████████                   | 150/279 [02:44<02:27,  1.14s/it]

[Epoch 16, step 150] loss: 0.07662


 72%|█████████████████████████████▍           | 200/279 [03:36<01:24,  1.07s/it]

[Epoch 16, step 200] loss: 0.05461


 90%|████████████████████████████████████▋    | 250/279 [04:29<00:29,  1.01s/it]

[Epoch 16, step 250] loss: 0.09634


100%|█████████████████████████████████████████| 279/279 [05:00<00:00,  1.08s/it]



[Epoch 16, step  17] val loss: 3.02025 accuracy: 0.699999988079071 auroc: 0.7499337792396545



 18%|███████▌                                  | 50/279 [00:51<03:46,  1.01it/s]

[Epoch 17, step  50] loss: 0.04751


 36%|██████████████▋                          | 100/279 [01:41<02:56,  1.02it/s]

[Epoch 17, step 100] loss: 0.08830


 54%|██████████████████████                   | 150/279 [02:33<02:23,  1.11s/it]

[Epoch 17, step 150] loss: 0.06510


 72%|█████████████████████████████▍           | 200/279 [03:23<01:17,  1.02it/s]

[Epoch 17, step 200] loss: 0.06323


 90%|████████████████████████████████████▋    | 250/279 [04:14<00:29,  1.00s/it]

[Epoch 17, step 250] loss: 0.08161


100%|█████████████████████████████████████████| 279/279 [04:44<00:00,  1.02s/it]



[Epoch 17, step  17] val loss: 2.60270 accuracy: 0.6944444179534912 auroc: 0.7360882759094238



 18%|███████▌                                  | 50/279 [00:52<04:23,  1.15s/it]

[Epoch 18, step  50] loss: 0.04234


 36%|██████████████▋                          | 100/279 [01:43<02:59,  1.00s/it]

[Epoch 18, step 100] loss: 0.05125


 54%|██████████████████████                   | 150/279 [02:35<02:09,  1.01s/it]

[Epoch 18, step 150] loss: 0.04661


 72%|█████████████████████████████▍           | 200/279 [03:26<01:19,  1.01s/it]

[Epoch 18, step 200] loss: 0.05792


 90%|████████████████████████████████████▋    | 250/279 [04:16<00:29,  1.01s/it]

[Epoch 18, step 250] loss: 0.06823


100%|█████████████████████████████████████████| 279/279 [04:46<00:00,  1.03s/it]



[Epoch 18, step  17] val loss: 2.36051 accuracy: 0.7055555582046509 auroc: 0.7328308820724487



 18%|███████▌                                  | 50/279 [00:50<03:49,  1.00s/it]

[Epoch 19, step  50] loss: 0.02153


 36%|██████████████▋                          | 100/279 [01:42<03:06,  1.04s/it]

[Epoch 19, step 100] loss: 0.04844


 54%|██████████████████████                   | 150/279 [02:34<02:07,  1.02it/s]

[Epoch 19, step 150] loss: 0.05362


 72%|█████████████████████████████▍           | 200/279 [03:25<01:20,  1.02s/it]

[Epoch 19, step 200] loss: 0.06858


 90%|████████████████████████████████████▋    | 250/279 [04:16<00:29,  1.02s/it]

[Epoch 19, step 250] loss: 0.05105


100%|█████████████████████████████████████████| 279/279 [04:46<00:00,  1.03s/it]



[Epoch 19, step  17] val loss: 2.23323 accuracy: 0.6981481313705444 auroc: 0.7377352714538574



 18%|███████▌                                  | 50/279 [00:51<03:45,  1.01it/s]

[Epoch 20, step  50] loss: 0.04249


 36%|██████████████▋                          | 100/279 [01:42<03:03,  1.03s/it]

[Epoch 20, step 100] loss: 0.02403


 54%|██████████████████████                   | 150/279 [02:33<02:11,  1.02s/it]

[Epoch 20, step 150] loss: 0.02944


 72%|█████████████████████████████▍           | 200/279 [03:24<01:22,  1.04s/it]

[Epoch 20, step 200] loss: 0.03239


 90%|████████████████████████████████████▋    | 250/279 [04:14<00:28,  1.02it/s]

[Epoch 20, step 250] loss: 0.05485


100%|█████████████████████████████████████████| 279/279 [04:44<00:00,  1.02s/it]



[Epoch 20, step  17] val loss: 2.42863 accuracy: 0.6851851940155029 auroc: 0.7346985340118408

Finished Training!


In [10]:
correct_preds = 0
total_preds = 0

preds_all_val = torch.tensor([]).cuda()
labels_all_val = torch.tensor([]).cuda()
for i, data in enumerate(test_dataloader, 0):
    images, texts, labels = data
    images = images.to(device)
    texts = texts.to(device)
    labels = labels.float().squeeze().to(device)

    with torch.no_grad():
        images = model.encode_image(images) # input_dim: batch_size x 3 x H x W; output_dim: batch_size x 512
        texts = model.encode_text(texts.squeeze()) # input_dim: batch_size x 77; output_dim: batch_size x 512
    
    fused_images_texts = torch.hstack((images,texts))
    fused_images_texts.requires_grad_()
    fused_images_texts = fused_images_texts.float()

    with torch.no_grad():
        output = net(fused_images_texts)

    loss = criterion(output.squeeze(), labels)

    running_loss += loss.item()

    correct_preds += sum(torch.sigmoid(output).squeeze().round()==labels)
    total_preds += len(labels)

    preds_all_val = torch.cat((preds_all_val, torch.sigmoid(output).squeeze()))
    labels_all_val = torch.cat((labels_all_val, labels))

auroc_score = auroc(preds_all_val, labels_all_val.int())
print(f"accuracy: {torch.mean(preds_all_val.round()==labels_all_val, dtype=torch.float32)} auroc: {auroc_score}\n")



accuracy: 0.733500063419342 auroc: 0.79988694190979

