In [4]:
import torch
import torch.nn as nn
import math

from tqdm import tqdm
from helper import get_transformer_input, save_model, rank
from datasets import Recipe1MDataset
from models import TextEncoder, ImageEncoder
from trainer import train
from helper import freeze_params
from torch.utils.data import DataLoader
from transformers import BertTokenizer

In [14]:
saved_model_path = 'saved_models/model.pt'
saved_weights = torch.load(saved_model_path, map_location='cpu')

#transformer_model_path = '/common/home/as3503/as3503/courses/cs536/final_project/final_project/saved_models/1s0qc5ue/model_train_encoders_False_epoch_0.pt'

#transformer_weights = torch.load(transformer_model_path, map_location='cpu')
device = 'cuda:1'
text_encoder = TextEncoder(2, 2)
text_encoder.load_state_dict(saved_weights['txt_encoder'])
text_encoder = text_encoder.to(device)
image_encoder = ImageEncoder()
image_encoder.load_state_dict(saved_weights['img_encoder'])
image_encoder = image_encoder.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
train_dataset = Recipe1MDataset(part='train')
val_dataset = Recipe1MDataset(part='val')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
save_dir = 'saved_models/'

freeze_params(text_encoder)
freeze_params(image_encoder)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In Recipe1MDataset
In Recipe1MDataset


In [36]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe = torch.transpose(pe, 0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[0, :x.size(1), :]
        return self.dropout(x)


class CrossModalAttention(nn.Module):
    def __init__(self, model_dim=768, n_heads=2, n_layers=2, num_image_patches=197, num_classes=2, drop_rate=0.1):
        super().__init__()
        self.text_pos_embed = SinusoidalPositionalEncoding(model_dim, dropout=drop_rate)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        self.image_pos_embed = nn.Parameter(torch.zeros(1, num_image_patches + 1, model_dim))
        self.image_pos_drop = nn.Dropout(p=drop_rate)
        layers = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=n_heads,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layers, num_layers=n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)
        
    def forward(self, image_features, text_features, src_key_padding_mask=None):
        batch_size = image_features.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        image_features = torch.cat((cls_token, image_features), dim=1)
        image_features = image_features + self.image_pos_embed
        image_features = self.image_pos_drop(image_features)
        
        text_features = self.text_pos_embed(text_features)
        
        sep_token = self.sep_token.expand(batch_size, -1, -1)
        transformer_input = torch.cat((image_features, sep_token, text_features), dim=1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(batch_size, image_features.shape[1] + 1).to(transformer_input.device), src_key_padding_mask), dim=1)
        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask=src_key_padding_mask)
        cls_outputs = transformer_outputs[:, 0, :]
        return self.cls_projection(cls_outputs)

In [37]:

# cm_transformer = CrossModalAttention().to(device)

cm_transformer = CrossModalAttention()
#cm_transformer.load_state_dict(transformer_weights['cm_transformer'])
cm_transformer = cm_transformer.to(device)

In [38]:
criterion = nn.CrossEntropyLoss()
train_encoders = False
if train_encoders:
    optimizer = torch.optim.Adam(
        [
            {'params': image_encoder.parameters()},
            {'params': text_encoder.parameters()},
            {'params': cm_transformer.parameters()}
        ],
        lr=1e-5
    )
else:
    optimizer = torch.optim.Adam(cm_transformer.parameters(), lr=1e-5)

In [40]:
cm_transformer.train()

train_loss, total_samples = 0, 0
num_its = 0

for text, image in tqdm(train_loader):
    
    num_its += 1
    text_inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
    text_outputs = text_encoder(**text_inputs)
    image_outputs = image_encoder(image.to(device))
    transformer_image_inputs, transformer_text_inputs, output_attention_mask, ground_truth = \
        get_transformer_input(image_outputs, text_outputs, text_inputs.attention_mask)
    text_padding_mask = ~output_attention_mask.bool()
    outputs = cm_transformer(transformer_image_inputs.to(device), transformer_text_inputs.to(device), text_padding_mask.to(device))
    loss = criterion(outputs, ground_truth.to(device).long())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    train_loss += loss.item() * image.shape[0]
    total_samples += image.shape[0]

    if num_its % 100 == 0:
        print('Train loss', round(train_loss / total_samples, 4))
    
    if num_its % 1000 == 0:
        print('Saving model')
        import os
        save_dict = {
            'cm_transformer': cm_transformer.state_dict()
        }
        save_dir = 'saved_models'
        save_model(save_dict, fpath=os.path.join(save_dir, f'temp_model.pt'))

  0%|▏                                                                                                                                | 100/56541 [00:24<3:41:12,  4.25it/s]

Train loss 0.1208


  0%|▍                                                                                                                                | 200/56541 [00:50<3:43:15,  4.21it/s]

Train loss 0.0677


  1%|▋                                                                                                                                | 300/56541 [01:15<4:20:26,  3.60it/s]

Train loss 0.0472


  1%|▉                                                                                                                                | 401/56541 [01:40<3:32:25,  4.40it/s]

Train loss 0.0362


  1%|█▏                                                                                                                               | 500/56541 [02:05<3:28:40,  4.48it/s]

Train loss 0.0299


  1%|█▎                                                                                                                               | 600/56541 [02:31<3:36:15,  4.31it/s]

Train loss 0.0256


  1%|█▌                                                                                                                               | 700/56541 [03:10<7:34:11,  2.05it/s]

Train loss 0.0226


  1%|█▊                                                                                                                               | 800/56541 [03:52<6:15:09,  2.48it/s]

Train loss 0.02


  2%|██                                                                                                                               | 900/56541 [04:33<6:34:57,  2.35it/s]

Train loss 0.0219


  2%|██▎                                                                                                                              | 999/56541 [05:12<4:49:42,  3.20it/s]

Train loss 0.0201
Saving model





NameError: name 'os' is not defined

In [47]:
image_encoder.eval()
text_encoder.eval()
cm_transformer.eval()
softmax = nn.Softmax(dim=-1)

val_loss, total_samples = 0, 0
for text, image in tqdm(val_loader):
    text_inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
    text_outputs = text_encoder(**text_inputs)
    image_outputs = image_encoder(image.to(device))
    transformer_image_inputs, transformer_text_inputs, output_attention_mask, ground_truth = \
        get_transformer_input(image_outputs, text_outputs, text_inputs.attention_mask)
    text_padding_mask = ~output_attention_mask.bool()
    outputs = cm_transformer(transformer_image_inputs[:8].to(device), transformer_text_inputs[:8].to(device),
                             text_padding_mask[:8].to(device))
#     loss = criterion(outputs, ground_truth.to(device).long())

    print(outputs)
    print(softmax(outputs))

    break

  0%|                                                                                                                                             | 0/12148 [00:00<?, ?it/s]

tensor([[-6.5810,  6.2637],
        [-6.3867,  5.9277],
        [-5.7426,  5.3994],
        [-6.7491,  5.8510],
        [-5.2724,  5.1299],
        [-5.8812,  5.6835],
        [-5.9786,  5.4465],
        [-6.5578,  5.7299]], device='cuda:1', grad_fn=<AddmmBackward0>)
tensor([[2.6403e-06, 1.0000e+00],
        [4.4867e-06, 1.0000e+00],
        [1.4491e-05, 9.9999e-01],
        [3.3718e-06, 1.0000e+00],
        [3.0360e-05, 9.9997e-01],
        [9.4950e-06, 9.9999e-01],
        [1.0918e-05, 9.9999e-01],
        [4.6081e-06, 1.0000e+00]], device='cuda:1', grad_fn=<SoftmaxBackward0>)





In [50]:
image_encoder.eval()
text_encoder.eval()
cm_transformer.eval()
softmax = nn.Softmax(dim=-1)

val_loss, total_samples = 0, 0
for text, image in tqdm(val_loader):
    text_inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
    text_outputs = text_encoder(**text_inputs)
    image_outputs = image_encoder(image.to(device))
    transformer_image_inputs, transformer_text_inputs, output_attention_mask, ground_truth = \
        get_transformer_input(image_outputs, text_outputs, text_inputs.attention_mask)
    text_padding_mask = ~output_attention_mask.bool()
    
    indices = torch.randperm(transformer_image_inputs.size()[0])
#     outputs = cm_transformer(transformer_image_inputs[indices].to(device), transformer_text_inputs[indices].to(device), text_padding_mask[indices].to(device))
        
        
    outputs = cm_transformer(transformer_image_inputs[indices].to(device), transformer_text_inputs[indices].to(device),
                             text_padding_mask[indices].to(device))
#     loss = criterion(outputs, ground_truth.to(device).long())

    print(outputs)
    print(softmax(outputs))
    print(ground_truth[indices])

    break

  0%|                                                                                                                                             | 0/12148 [00:00<?, ?it/s]

tensor([[-5.7622,  5.5879],
        [-6.4880,  6.0461],
        [ 5.6285, -5.5787],
        [-5.9732,  5.6619],
        [-6.4167,  6.1035],
        [ 7.0654, -6.5679],
        [ 5.2722, -5.1182],
        [ 6.6224, -5.9077],
        [-6.7211,  5.7733],
        [ 5.0049, -5.4297],
        [ 6.6561, -6.0565],
        [ 5.7452, -5.7584],
        [ 6.0133, -5.3118],
        [ 7.0897, -6.4577],
        [-6.0918,  5.5260],
        [ 4.9967, -4.9360],
        [ 7.3340, -6.4438],
        [ 5.9131, -5.8019],
        [ 6.1362, -6.1016],
        [ 5.6348, -5.7662],
        [-6.8971,  6.0547],
        [ 6.6107, -6.1656],
        [-5.0777,  4.9768],
        [ 6.0133, -5.3118]], device='cuda:1', grad_fn=<AddmmBackward0>)
tensor([[1.1768e-05, 9.9999e-01],
        [3.6017e-06, 1.0000e+00],
        [9.9999e-01, 1.3575e-05],
        [8.8499e-06, 9.9999e-01],
        [3.6521e-06, 1.0000e+00],
        [1.0000e+00, 1.1998e-06],
        [9.9997e-01, 3.0725e-05],
        [1.0000e+00, 3.6161e-06],
        [3.7




In [None]:
min_val_loss = float('inf')
project_name = 'cross_modal_attention'
wandb.init(project=project_name, entity='cs536')
save_dir = os.path.join(save_dir, wandb.run.id)
os.makedirs(save_dir, exist_ok=True)
criterion = nn.CrossEntropyLoss()
if train_encoders:
    optimizer = torch.optim.Adam(
        [
            {'params': image_encoder.parameters()},
            {'params': text_encoder.parameters()},
            {'params': cm_transformer.parameters()}
        ],
        lr=lr
    )
else:
    optimizer = torch.optim.Adam(cm_transformer.parameters(), lr=lr)


for epoch in range(num_epochs):
    train_loss = train_one_epoch(image_encoder, text_encoder, cm_transformer, train_dataloader, 
                                tokenizer, criterion, optimizer, train_encoders, device)
    val_loss = evaluate(image_encoder, text_encoder, cm_transformer, 
                        val_dataloader, tokenizer, criterion, device)

    # if val_loss < min_val_loss:
    min_val_loss = val_loss
    if train_encoders:
        save_dict = {
            'image_encoder': image_encoder.state_dict(),
            'text_encoder': text_encoder.state_dict(),
            'cm_transformer': cm_transformer.state_dict()
        }
    else:
        save_dict = {
            'cm_transformer': cm_transformer.state_dict()
        }

    save_dict['train_loss'] = train_loss
    save_dict['val_loss'] = val_loss

    save_model(save_dict, fpath=os.path.join(save_dir, f'model_train_encoders_{train_encoders}_epoch_{epoch}.pt'))

    print(f'Epoch: {epoch}, Train loss: {train_loss}, Val loss: {val_loss}')