In [1]:
import config as CFG

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig
from transformers import ViTModel, ViTConfig
import wandb

In [2]:
### Unmodified: same definitions as in models.py

###################### TEXT TOWER ####################################

class TextEncoder(nn.Module):
    """BERT text tower that returns one embedding vector per input sequence.

    Args:
        model_name: identifier handed to ``BertModel.from_pretrained`` or
            ``BertConfig.from_pretrained``.
        pretrained: when truthy, load the pretrained checkpoint; otherwise only
            the configuration is loaded and the model is built from it.
        trainable: value assigned to ``requires_grad`` on every BERT parameter.
    """

    def __init__(self, model_name=CFG.text_model_name, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            BertModel.from_pretrained(model_name)
            if pretrained
            else BertModel(config=BertConfig.from_pretrained(model_name))
        )

        # Freeze or unfreeze the whole backbone in a single call.
        self.model.requires_grad_(trainable)

        # The hidden state of the CLS token (position 0) is used as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the last-layer hidden state at ``target_token_idx`` for each sequence."""
        hidden_states = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]




###################### IMAGE TOWER ####################################


class ImageEncoder(nn.Module):
    """ViT image tower that returns one embedding vector per input image.

    Args:
        model_name: identifier handed to ``ViTModel.from_pretrained`` or
            ``ViTConfig.from_pretrained``.
        pretrained: when truthy, load the pretrained checkpoint; otherwise only
            the configuration is loaded and the model is built from it.
        trainable: value assigned to ``requires_grad`` on every ViT parameter.
    """

    def __init__(self, model_name=CFG.image_model_name, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        self.model = (
            ViTModel.from_pretrained(model_name)
            if pretrained
            else ViTModel(config=ViTConfig.from_pretrained(model_name))
        )

        # Freeze or unfreeze the whole backbone in a single call.
        self.model.requires_grad_(trainable)

        # Token position whose hidden state is returned (presumably the [CLS] token -- TODO confirm).
        self.target_token_idx = 0

    def forward(self, image):
        """Return the last-layer hidden state at ``target_token_idx`` for each image."""
        hidden_states = self.model(image).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]

###################### PROJECTION HEAD on top ####################################

class ProjectionHead(nn.Module):
    """Residual projection head mapping encoder features to the shared embedding space.

    Computes ``LayerNorm(Dropout(Linear(GELU(p))) + p)`` where ``p = Linear(x)``.

    Args:
        embedding_dim: size of the last dimension of the incoming features.
        projection_dim: size of the last dimension of the returned embeddings.
        dropout: dropout probability applied on the non-residual branch.
    """

    def __init__(self, embedding_dim, projection_dim=CFG.projection_dim, dropout=CFG.dropout):
        super().__init__()
        # Attribute names and creation order are kept as-is: they determine the
        # state_dict keys and the order in which weights are initialised.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x``, refine it through the residual branch and layer-normalise the sum."""
        skip = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(skip)))
        return self.layer_norm(refined + skip)

In [3]:
class AvgMeter:
    """Running, count-weighted average of a scalar metric such as a per-batch loss.

    Attributes:
        name: label used by ``__repr__``.
        count: total weight (e.g. number of samples) seen since the last reset.
        sum_loss: weighted sum of the values seen since the last reset.
        avg_loss: ``sum_loss / count``; it is a tensor whenever tensors were fed in,
            so callers may keep using ``avg_loss.item()``.
    """

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.avg_loss, self.sum_loss, self.count = [0] * 3

    def update(self, loss, count=1):
        """Fold one observation into the running average.

        Args:
            loss: the observed value (Python number or tensor).
            count: weight of the observation, typically the batch size.
        """
        # A loss that is still attached to the autograd graph would chain the graph
        # of every step into ``sum_loss``; only the value is needed for reporting.
        if torch.is_tensor(loss):
            loss = loss.detach()

        self.count += count
        self.sum_loss += loss * count
        # Keep the previous average while the total weight is still zero.
        if self.count:
            self.avg_loss = self.sum_loss / self.count

    def __repr__(self):
        return f"{self.name}: avg_loss = {self.avg_loss:.4f}"

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]


In [4]:
from copy import deepcopy
from tqdm import tqdm

In [5]:
## Same as the CLIP projection model, but implements MoCo so that both the text and image towers can be fine-tuned
# as well, while still keeping a large number of negative contrastive examples despite the smaller batch size

class CLIPProjMoco(nn.Module):
    """CLIP-style dual encoder with MoCo momentum key encoders and key queues.

    Query side (returned by ``forward``): ``image_encoder`` / ``text_encoder``
    followed by ``image_projection`` / ``text_projection``.

    Key side: deep copies of the four query modules with ``requires_grad`` turned
    off. Their weights follow the query weights through the exponential moving
    average applied by ``_momentum_update_key_encoders``.

    ``image_queue`` and ``text_queue`` are ``(K, proj_dim)`` ring buffers of key
    embeddings written by ``_dequeue_and_enqueue``. They are registered as
    non-persistent buffers: they move with ``model.to(device)`` but do not appear
    in ``state_dict()``.

    Args:
        temperature: stored on the instance; not used inside this class.
        image_embedding: feature size produced by the image encoder.
        text_embedding: feature size produced by the text encoder.
        proj_dim: size of the projected embeddings and of each queue row.
        trainable: whether the query encoders are fine-tuned. When falsy they are
            run under ``torch.no_grad()``.
        K: number of rows in each queue.
        m: momentum coefficient of the key-side moving average.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
        proj_dim = CFG.projection_dim,
        trainable=CFG.trainable,
        K=CFG.K,
        m=0.999
    ):
        super().__init__()
        # Forward the constructor arguments so the towers, heads and queues always agree
        # with each other (identical to the previous behaviour for the default values).
        self.image_encoder = ImageEncoder(trainable=trainable)
        self.text_encoder = TextEncoder(trainable=trainable)
        self.image_projection = ProjectionHead(embedding_dim=image_embedding, projection_dim=proj_dim)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding, projection_dim=proj_dim)
        self.proj_dim = proj_dim
        self.temperature = temperature
        self.trainable = trainable

        # MOCO parameters
        self.K = K
        self.m = m

        # Init key encoders: frozen copies of every query-side module.
        self.image_key_encoder = self._frozen_copy(self.image_encoder)
        self.text_key_encoder = self._frozen_copy(self.text_encoder)
        self.image_key_projection = self._frozen_copy(self.image_projection)
        self.text_key_projection = self._frozen_copy(self.text_projection)

        # Init queues with random unit-norm rows, matching the L2-normalised keys
        # that the training loop enqueues.
        self.register_buffer(
            "image_queue",
            nn.functional.normalize(torch.randn(self.K, self.proj_dim), dim=1),
            persistent=False,
        )
        self.register_buffer(
            "text_queue",
            nn.functional.normalize(torch.randn(self.K, self.proj_dim), dim=1),
            persistent=False,
        )

        # Row index at which the next batch of keys is written.
        self.queue_ptr = 0

    @staticmethod
    def _frozen_copy(module):
        """Return a deep copy of ``module`` whose parameters all have ``requires_grad=False``."""
        clone = deepcopy(module)
        for param_k in clone.parameters():
            param_k.requires_grad = False
        return clone

    def _encode_query(self, encoder, projection, *args, **kwargs):
        """Run ``encoder`` (without gradients when not trainable), then ``projection``."""
        if self.trainable:
            features = encoder(*args, **kwargs)
        else:
            with torch.no_grad():
                features = encoder(*args, **kwargs)
        # The projection head is always run in the ambient gradient mode.
        return projection(features)

    def encode_text(self, text):
        """Return query-side text embeddings for ``text["input_ids"]`` / ``text["attention_mask"]``."""
        return self._encode_query(
            self.text_encoder,
            self.text_projection,
            input_ids=text["input_ids"],
            attention_mask=text["attention_mask"],
        )

    @torch.no_grad()
    def key_encode_text(self, text):
        """Return key-side text embeddings; keys never take part in backpropagation."""
        text_features = self.text_key_encoder(
            input_ids=text["input_ids"], attention_mask=text["attention_mask"]
        )
        return self.text_key_projection(text_features)

    def encode_image(self, image):
        """Return query-side image embeddings for ``image``."""
        return self._encode_query(self.image_encoder, self.image_projection, image)

    @torch.no_grad()
    def key_encode_image(self, image):
        """Return key-side image embeddings; keys never take part in backpropagation."""
        image_features = self.image_key_encoder(image)
        return self.image_key_projection(image_features)

    ## Update all key parameters (both encoders and projection module)
    @torch.no_grad()
    def _momentum_update_key_encoders(self):
        """Move every key parameter towards its query counterpart: ``k = m * k + (1 - m) * q``."""
        module_pairs = (
            (self.image_encoder, self.image_key_encoder),
            (self.text_encoder, self.text_key_encoder),
            (self.image_projection, self.image_key_projection),
            (self.text_projection, self.text_key_projection),
        )
        for query_module, key_module in module_pairs:
            for param_q, param_k in zip(query_module.parameters(), key_module.parameters()):
                param_k.data.mul_(self.m).add_(param_q.data, alpha=1. - self.m)

    @staticmethod
    def _ring_write(queue, keys, start):
        """Copy ``keys`` into ``queue`` from row ``start``, wrapping around past the last row."""
        n_keys = keys.size(0)
        head = min(n_keys, queue.size(0) - start)
        # Slice assignment copies across devices and dtypes.
        queue[start:start + head] = keys[:head]
        if head < n_keys:
            queue[:n_keys - head] = keys[head:]

    # Add new minibatch _k to queue and remove the oldest minibatch in queue
    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_k, text_k):
        """Overwrite the oldest queue rows with the keys of the current batch.

        The batch size no longer has to divide ``K``: a batch that reaches the end
        of the queue wraps around to its beginning, so a smaller last batch of an
        epoch is handled.

        Raises:
            ValueError: if the two key batches differ in size or exceed ``K`` rows.
        """
        bs = image_k.size(0)
        if text_k.size(0) != bs:
            raise ValueError(
                f"image and text key batches differ in size: {bs} vs {text_k.size(0)}"
            )
        if bs > self.K:
            raise ValueError(f"batch size {bs} is larger than the queue size K={self.K}")

        self._ring_write(self.image_queue, image_k.detach(), self.queue_ptr)
        self._ring_write(self.text_queue, text_k.detach(), self.queue_ptr)
        self.queue_ptr = (self.queue_ptr + bs) % self.K  # move pointer

    def forward(self, image,text):
        """Return the query-side embeddings as ``{"image_embed": ..., "text_embed": ...}``."""
        image_embeddings = self.encode_image(image)
        text_embeddings = self.encode_text(text)

        return {"image_embed": image_embeddings, "text_embed": text_embeddings}

In [6]:
def train_one_MOCO_epoch(model, loss_fn, train_loader, optimizer,device):
    """Train ``model`` for one pass over ``train_loader`` with MoCo-style keys.

    For every batch:
      1. update the key modules via ``model._momentum_update_key_encoders()``;
      2. encode the batch with the key modules, L2-normalise the result and write
         it into the model's queues;
      3. compute the query embeddings with ``model(image, text)`` and evaluate
         ``loss_fn(output, keys)`` against the full, freshly updated queues;
      4. back-propagate and step the optimizer.

    Args:
        model: object exposing ``_momentum_update_key_encoders``, ``key_encode_image``,
            ``key_encode_text``, ``_dequeue_and_enqueue``, ``image_queue`` and ``text_queue``.
        loss_fn: callable taking the query output dict and the key dict.
        train_loader: iterable of batches with ``"image"``, ``"input_ids"`` and
            ``"attention_mask"`` entries.
        optimizer: optimizer over the query-side parameters.
        device: device the batch tensors and the keys are moved to.

    Returns:
        The ``AvgMeter`` holding the sample-weighted average training loss.
    """
    loss_meter = AvgMeter()

    tqdm_object = tqdm(train_loader, total=len(train_loader))

    for batch in tqdm_object:

        image = batch["image"].to(device)
        text = {"input_ids": batch["input_ids"].to(device), "attention_mask": batch["attention_mask"].to(device)}

        # Update the momentum encoder
        # Generate key for this batch, and update the queue
        with torch.no_grad():

            model._momentum_update_key_encoders()

            key_image_features = model.key_encode_image(image)
            key_text_features = model.key_encode_text(text)

            key_image_features = key_image_features / key_image_features.norm(dim=-1, keepdim=True)
            key_text_features = key_text_features / key_text_features.norm(dim=-1, keepdim=True)

            model._dequeue_and_enqueue(key_image_features,key_text_features)

        # Now the keys are the updated queue
        keys_for_this_batch = {"image_embed" : model.image_queue.to(device), "text_embed": model.text_queue.to(device)}

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        #compute prediction for the batch
        output = model(image,text)

        #compute loss and its gradients
        loss = loss_fn(output,keys_for_this_batch)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report.
        # Detach first: accumulating the attached loss would keep the autograd
        # history of every step alive inside the meter for the whole epoch.
        count = batch["image"].size(0)
        loss_meter.update(loss.detach(), count)

        # Convert once to a Python float so wandb and tqdm both receive a plain number.
        running_loss = loss_meter.avg_loss.item()
        wandb.log({"loss": running_loss, "lr" : get_lr(optimizer)  } )
        tqdm_object.set_postfix(train_loss=running_loss)

    return loss_meter

In [7]:
from dataloader import get_dataloader
from tokenizer import get_tokenizer,get_feature_extractor
from losses import CLIPMoCOLoss, CLIPLoss
import itertools
from training import valid_one_epoch
from transformers import logging


In [8]:
import os

logging.set_verbosity_error()

wandb.init(project="master_test_1",
           config={
               "batch_size": CFG.batch_size,
               "learning_rate": CFG.head_lr,
               "dataset": "flickr30k",
           },
           group="group_test",
           name="Moco")
tokenizer = get_tokenizer(CFG.text_model_name)
feature_extractor = get_feature_extractor(CFG.image_model_name)

dataloader_train = get_dataloader(tokenizer=tokenizer,feature_extractor=feature_extractor,batch_size=CFG.batch_size,shuffle=CFG.shuffle_train,num_workers=CFG.num_workers,split="train")
# Validation batches are not shuffled, so the in-batch contrastive validation loss
# is computed on the same batches every epoch and stays comparable for
# best-checkpoint selection.
dataloader_valid = get_dataloader(tokenizer=tokenizer,feature_extractor=feature_extractor,batch_size=CFG.batch_size,shuffle=False,num_workers=CFG.num_workers,split="val")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = CLIPProjMoco().to(device)
loss_train = CLIPMoCOLoss()
loss_valid = CLIPLoss()

# Only query-side modules are optimised; key-side modules follow them by momentum.
head_params = {
    "params": itertools.chain(
        model.image_projection.parameters(), model.text_projection.parameters()
    ),
    "lr": CFG.head_lr,
    "weight_decay": CFG.weight_decay,
}
if not CFG.trainable:
    # Frozen towers: train the projection heads only.
    params = [head_params]
else:
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        head_params,
    ]
# weight_decay=0. is the default for groups that do not set their own value (the encoders).
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=CFG.T_max)

# Create the checkpoint directory up front so torch.save cannot fail on a missing
# folder after a full epoch of training.
os.makedirs("weights", exist_ok=True)

# NOTE(review): only the projection heads are checkpointed. When CFG.trainable is
# true the fine-tuned encoders are not saved -- confirm whether that is intended.
best_loss = float('inf')

for epoch in range(CFG.epochs):

    print(f"Epoch: {epoch + 1}")
    model.train()
    train_loss = train_one_MOCO_epoch(model, loss_train, dataloader_train, optimizer,device)

    model.eval()

    with torch.no_grad():
        valid_loss = valid_one_epoch(model,loss_valid,dataloader_valid,device)

    if valid_loss.avg_loss < best_loss:
        best_loss = valid_loss.avg_loss
        torch.save(model.image_projection.state_dict(), "weights/img_proj_best.pt")
        torch.save(model.text_projection.state_dict(), "weights/text_proj_best.pt")

    lr_scheduler.step()


torch.save(model.image_projection.state_dict(), "weights/img_proj_last.pt")
torch.save(model.text_projection.state_dict(), "weights/text_proj_last.pt")

# Close the wandb run explicitly; in a notebook the process does not exit after the cell.
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcarrelv[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using downloaded and verified file: ./flickr30k/flickr30k_train.json
Using downloaded and verified file: ./flickr30k/flickr30k_val.json
Epoch: 1


100%|██████████| 453/453 [06:41<00:00,  1.13it/s, train_loss=6.89]
100%|██████████| 15/15 [00:06<00:00,  2.32it/s, valid_loss=4.16]


Epoch: 2


100%|██████████| 453/453 [06:41<00:00,  1.13it/s, train_loss=6.92]
100%|██████████| 15/15 [00:06<00:00,  2.16it/s, valid_loss=4.16]
