In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import VisualBertModel, VisualBertConfig, BertTokenizer
from sklearn.model_selection import train_test_split
import pickle
import tqdm
from sklearn.metrics import accuracy_score
import numpy as np

In [2]:
class Cifar100Dataset(Dataset):
    def __init__(self, visual_embeddings, labels, tokenizer):
        self.visual_embeddings = visual_embeddings
        self.labels = labels
        self.tokenizer = tokenizer
        self.num_labels = len(set(self.labels))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        visual_embedding = self.visual_embeddings[idx]
        # text_inputs = tokenizer(text_inputs, padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        text_inputs = tokenizer("", padding="max_length", max_length=32, truncation=True, return_tensors="pt")
        
        # remove batch dimension
        for k,v in text_inputs.items():
            text_inputs[k] = v.squeeze()
            
        label = torch.zeros(self.num_labels)
        label[self.labels[idx]] = 1
        return visual_embedding, text_inputs, label

In [3]:
config = VisualBertConfig.from_pretrained("uclanlp/visualbert-nlvr2")
visual_bert = VisualBertModel(config)
print(visual_bert.config.visual_embedding_dim)

1024


In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [5]:
class VisualBERTForImageClassification(nn.Module):
    def __init__(self, config, num_classes, visual_embedding_size):
        super().__init__()
        self.visual_bert = VisualBertModel(config)
        # self.projection = nn.Linear(256*7*7, visual_embedding_size * 256)
        # self.classifier = nn.Linear(config.hidden_size, num_classes)
        self.classifier = nn.Sequential(
            nn.Linear(in_features=config.hidden_size, out_features=1536, bias=True),
            nn.LayerNorm((1536,), eps=1e-05, elementwise_affine=True),
            nn.GELU(approximate='none'),
            nn.Linear(in_features=1536, out_features=num_classes, bias=True)
        )

    def forward(self, visual_embeddings, text_inputs):
#         # Project to the size that VisualBert expects
#         input_visual_embeddings = self.projection(visual_embeddings.view(visual_embeddings.size(0), -1))
        
#         # Reshape the visual embeddings back to 3D tensor
#         input_visual_embeddings = input_visual_embeddings.view(input_visual_embeddings.size(0), 256, -1)
                
        outputs = self.visual_bert(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"],
            token_type_ids=text_inputs["token_type_ids"],
            visual_embeds=visual_embeddings
        )
        return self.classifier(outputs.pooler_output)

In [6]:
def collate_fn(batch):
        visual_embeddings, text_inputs, labels = zip(*batch)
        
        # Flatten the visual embeddings
        visual_embeddings = [embedding.view(embedding.size(0), -1) for embedding in visual_embeddings]
        
        visual_embeddings = torch.stack(visual_embeddings)
        
        labels = torch.stack(labels)
        
        input_ids = torch.stack([item['input_ids'] for item in text_inputs])
        attention_mask = torch.stack([item['attention_mask'] for item in text_inputs])
        token_type_ids = torch.stack([item['token_type_ids'] for item in text_inputs])
        
        text_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }
        
                
        return visual_embeddings, text_inputs, labels

In [7]:
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    for visual_embeddings, text_inputs, labels in dataloader:
        visual_embeddings = visual_embeddings.to(device)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(visual_embeddings, text_inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

In [8]:
with open('../cifar100-train-embeddings.pkl', 'rb') as f:
    train_embeddings_all = pickle.load(f)

with open('../cifar100-test-embeddings.pkl', 'rb') as f:
    test_embeddings_all = pickle.load(f)

In [9]:
def curtail_embeddings(embeddings):
    lowest = 1000
    
    for e in embeddings['embeddings']:
        lowest = min(lowest, e.size(0))
    
    for idx, e in enumerate(embeddings['embeddings']):
        embeddings['embeddings'][idx] = e[:lowest]

In [10]:
curtail_embeddings(train_embeddings_all)
curtail_embeddings(test_embeddings_all)

In [30]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_embeddings = train_embeddings_all['embeddings']
train_labels = train_embeddings_all['fine_labels']

val_embeddings = test_embeddings_all['embeddings']
val_labels = test_embeddings_all['fine_labels']
num_classes = len(set(train_labels))

In [32]:
model = VisualBERTForImageClassification(config, num_classes=num_classes, visual_embedding_size=1024)
model = model.to(device)

In [33]:
model

VisualBERTForImageClassification(
  (visual_bert): VisualBertModel(
    (embeddings): VisualBertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (visual_token_type_embeddings): Embedding(2, 768)
      (visual_position_embeddings): Embedding(512, 768)
      (visual_projection): Linear(in_features=1024, out_features=768, bias=True)
    )
    (encoder): VisualBertEncoder(
      (layer): ModuleList(
        (0): VisualBertLayer(
          (attention): VisualBertAttention(
            (self): VisualBertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
       

In [34]:
train_dataset = Cifar100Dataset(train_embeddings, train_labels, tokenizer)
val_dataset = Cifar100Dataset(val_embeddings, val_labels, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)

In [19]:
for name, param in model.named_parameters():
    if 'classifier' in name or 'pooler' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

In [20]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 20

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.25)

In [21]:
best_params = None
best_val_accuracy = -1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    
    for visual_embeddings, text_inputs, labels in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        visual_embeddings = visual_embeddings.to(device)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(visual_embeddings, text_inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    scheduler.step()
    
    train_loss = total_loss / len(train_dataloader)
    
    # Evaluate the model on the validation set
    model.eval()
    val_predictions = []
    val_labels = []
    
    with torch.no_grad():
        for visual_embeddings, text_inputs, labels in val_dataloader:
            visual_embeddings = visual_embeddings.to(device)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            labels = labels.to(device)
            
            outputs = model(visual_embeddings, text_inputs)
            _, preds = torch.max(outputs, 1)
            
            val_predictions.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    val_labels_idx = [np.argmax(tensor) for tensor in val_labels]
    val_accuracy = accuracy_score(val_labels_idx, val_predictions)
    
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

#     # Save a checkpoint if validation accuracy is greater than the previous best validation accuracy
#     if val_accuracy > best_val_accuracy:
#         best_val_accuracy = val_accuracy
#         best_params = model.state_dict()
#         torch.save(best_params, "best_model.pth")

Epoch 1/20: 100%|██████████| 391/391 [01:17<00:00,  5.06it/s]


Epoch 1/20, Loss: 3.7687, Validation Accuracy: 0.2400


Epoch 2/20: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


Epoch 2/20, Loss: 3.1842, Validation Accuracy: 0.2774


Epoch 3/20: 100%|██████████| 391/391 [01:13<00:00,  5.30it/s]


Epoch 3/20, Loss: 3.0307, Validation Accuracy: 0.3042


Epoch 4/20: 100%|██████████| 391/391 [01:12<00:00,  5.36it/s]


Epoch 4/20, Loss: 2.9477, Validation Accuracy: 0.3124


Epoch 5/20: 100%|██████████| 391/391 [01:12<00:00,  5.39it/s]


Epoch 5/20, Loss: 2.8935, Validation Accuracy: 0.3231


Epoch 6/20: 100%|██████████| 391/391 [01:19<00:00,  4.93it/s]


Epoch 6/20, Loss: 2.8545, Validation Accuracy: 0.3300


Epoch 7/20: 100%|██████████| 391/391 [01:13<00:00,  5.30it/s]


Epoch 7/20, Loss: 2.8223, Validation Accuracy: 0.3418


Epoch 8/20: 100%|██████████| 391/391 [01:13<00:00,  5.33it/s]


Epoch 8/20, Loss: 2.7906, Validation Accuracy: 0.3475


Epoch 9/20: 100%|██████████| 391/391 [01:13<00:00,  5.32it/s]


Epoch 9/20, Loss: 2.7600, Validation Accuracy: 0.3471


Epoch 10/20: 100%|██████████| 391/391 [01:05<00:00,  5.95it/s]


Epoch 10/20, Loss: 2.7356, Validation Accuracy: 0.3521


Epoch 11/20: 100%|██████████| 391/391 [00:53<00:00,  7.27it/s]


Epoch 11/20, Loss: 2.6786, Validation Accuracy: 0.3627


Epoch 12/20: 100%|██████████| 391/391 [00:58<00:00,  6.63it/s]


Epoch 12/20, Loss: 2.6644, Validation Accuracy: 0.3641


Epoch 13/20: 100%|██████████| 391/391 [01:38<00:00,  3.95it/s]


Epoch 13/20, Loss: 2.6588, Validation Accuracy: 0.3668


Epoch 14/20: 100%|██████████| 391/391 [01:39<00:00,  3.92it/s]


Epoch 14/20, Loss: 2.6509, Validation Accuracy: 0.3671


Epoch 15/20: 100%|██████████| 391/391 [01:40<00:00,  3.90it/s]


Epoch 15/20, Loss: 2.6444, Validation Accuracy: 0.3676


Epoch 16/20: 100%|██████████| 391/391 [01:40<00:00,  3.91it/s]


Epoch 16/20, Loss: 2.6400, Validation Accuracy: 0.3702


Epoch 17/20: 100%|██████████| 391/391 [01:39<00:00,  3.91it/s]


Epoch 17/20, Loss: 2.6334, Validation Accuracy: 0.3691


Epoch 18/20: 100%|██████████| 391/391 [01:40<00:00,  3.90it/s]


Epoch 18/20, Loss: 2.6291, Validation Accuracy: 0.3685


Epoch 19/20: 100%|██████████| 391/391 [01:40<00:00,  3.91it/s]


Epoch 19/20, Loss: 2.6267, Validation Accuracy: 0.3731


Epoch 20/20: 100%|██████████| 391/391 [01:39<00:00,  3.91it/s]


Epoch 20/20, Loss: 2.6177, Validation Accuracy: 0.3725


I think there was a bug, I didn't step through to the pooler? Try fixing this and also adjusting LR schedule

In [37]:
model = VisualBERTForImageClassification(config, num_classes=num_classes, visual_embedding_size=1024)
model = model.to(device)

train_dataset = Cifar100Dataset(train_embeddings, train_labels, tokenizer)
val_dataset = Cifar100Dataset(val_embeddings, val_labels, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)

for name, param in model.named_parameters():
    if 'classifier' or 'pooler' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 20

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.25)


best_params = None
best_val_accuracy = -1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    
    train_predictions = []
    train_labels_eval = []
    
    for visual_embeddings, text_inputs, labels in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        visual_embeddings = visual_embeddings.to(device)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(visual_embeddings, text_inputs)
        _, preds = torch.max(outputs, 1)
        
        train_predictions.extend(preds.cpu().numpy())
        train_labels_eval.extend(labels.cpu().numpy())
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    scheduler.step()
    
    train_loss = total_loss / len(train_dataloader)
    
    # Evaluate the model on the validation set
    model.eval()
    val_predictions = []
    val_labels_eval = []
        
    with torch.no_grad():
        for visual_embeddings, text_inputs, labels in val_dataloader:
            visual_embeddings = visual_embeddings.to(device)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            labels = labels.to(device)
            
            outputs = model(visual_embeddings, text_inputs)
            _, preds = torch.max(outputs, 1)
            
            val_predictions.extend(preds.cpu().numpy())
            val_labels_eval.extend(labels.cpu().numpy())
    
    val_labels_idx = [np.argmax(tensor) for tensor in val_labels_eval]
    val_accuracy = accuracy_score(val_labels_idx, val_predictions)
    
    train_labels_idx = [np.argmax(tensor) for tensor in train_labels_eval]
    train_accuracy = accuracy_score(train_labels_idx, train_predictions)
    
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Training Acc: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}")

#     # Save a checkpoint if validation accuracy is greater than the previous best validation accuracy
#     if val_accuracy > best_val_accuracy:
#         best_val_accuracy = val_accuracy
#         best_params = model.state_dict()
#         torch.save(best_params, "best_model.pth")

Epoch 1/20: 100%|██████████| 391/391 [02:16<00:00,  2.87it/s]


Epoch 1/20, Loss: 3.3235, Validation Accuracy: 0.2420


Epoch 2/20: 100%|██████████| 391/391 [02:16<00:00,  2.86it/s]


Epoch 2/20, Loss: 2.6895, Validation Accuracy: 0.3467


Epoch 3/20: 100%|██████████| 391/391 [03:29<00:00,  1.86it/s]


Epoch 3/20, Loss: 2.3941, Validation Accuracy: 0.3770


Epoch 4/20: 100%|██████████| 391/391 [04:15<00:00,  1.53it/s]


Epoch 4/20, Loss: 2.1592, Validation Accuracy: 0.3929


Epoch 5/20: 100%|██████████| 391/391 [04:16<00:00,  1.53it/s]


Epoch 5/20, Loss: 1.9511, Validation Accuracy: 0.4338


Epoch 6/20: 100%|██████████| 391/391 [02:38<00:00,  2.47it/s]


Epoch 6/20, Loss: 1.7538, Validation Accuracy: 0.4495


Epoch 7/20: 100%|██████████| 391/391 [02:15<00:00,  2.89it/s]


Epoch 7/20, Loss: 1.5708, Validation Accuracy: 0.4577


Epoch 8/20: 100%|██████████| 391/391 [02:16<00:00,  2.87it/s]


Epoch 8/20, Loss: 1.3681, Validation Accuracy: 0.4663


Epoch 9/20: 100%|██████████| 391/391 [02:16<00:00,  2.86it/s]


Epoch 9/20, Loss: 1.1658, Validation Accuracy: 0.4801


Epoch 10/20: 100%|██████████| 391/391 [02:17<00:00,  2.85it/s]


Epoch 10/20, Loss: 0.9839, Validation Accuracy: 0.4779


Epoch 11/20: 100%|██████████| 391/391 [02:23<00:00,  2.72it/s]


Epoch 11/20, Loss: 0.8019, Validation Accuracy: 0.4854


Epoch 12/20: 100%|██████████| 391/391 [04:18<00:00,  1.51it/s]


Epoch 12/20, Loss: 0.6786, Validation Accuracy: 0.4841


Epoch 13/20:  21%|██        | 82/391 [00:55<03:27,  1.49it/s]


KeyboardInterrupt: 

Try with pooler still frozen

In [None]:
model = VisualBERTForImageClassification(config, num_classes=num_classes, visual_embedding_size=1024)
model = model.to(device)

train_dataset = Cifar100Dataset(train_embeddings, train_labels, tokenizer)
val_dataset = Cifar100Dataset(val_embeddings, val_labels, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)

for name, param in model.named_parameters():
    if 'classifier' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 20

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.25)


best_params = None
best_val_accuracy = -1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    
    for visual_embeddings, text_inputs, labels in tqdm.tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        visual_embeddings = visual_embeddings.to(device)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        labels = labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(visual_embeddings, text_inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    
    scheduler.step()
    
    train_loss = total_loss / len(train_dataloader)
    
    # Evaluate the model on the validation set
    model.eval()
    val_predictions = []
    val_labels = []
    
    with torch.no_grad():
        for visual_embeddings, text_inputs, labels in val_dataloader:
            visual_embeddings = visual_embeddings.to(device)
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            labels = labels.to(device)
            
            outputs = model(visual_embeddings, text_inputs)
            _, preds = torch.max(outputs, 1)
            
            val_predictions.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    val_labels_idx = [np.argmax(tensor) for tensor in val_labels]
    val_accuracy = accuracy_score(val_labels_idx, val_predictions)
    
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

#     # Save a checkpoint if validation accuracy is greater than the previous best validation accuracy
#     if val_accuracy > best_val_accuracy:
#         best_val_accuracy = val_accuracy
#         best_params = model.state_dict()
#         torch.save(best_params, "best_model.pth")