In [None]:
#!/usr/bin/env python

'''
RadCLIP Mapper Model Training Code
'''

__author__ = "Christoper Alexander"
__copyright__ = "Copyright 2023"
__credits__ = ["Andrew D'Amico", "Christoper Alexander", "Katya Nosulko", "Vivek Chamala", "Matthew Conger"]
__license__ = ""
__version__ = "0.0.1"
__maintainer__ = "Andrew Damico"
__email__ = "andrew.damico@u.northwestern.edu"

In [1]:
import os
import pickle
import sys
from enum import Enum
from typing import Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
GPT_CHECKPOINT = "gptmedium_10k"

In [3]:
class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'

In [4]:
class ClipCocoDataset(Dataset):
    def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = GPT_CHECKPOINT,
                 normalize_prefix=False):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        with open(data_path, 'rb') as f:
            all_data = pickle.load(f)
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        sys.stdout.flush()
        self.prefixes = all_data["clip_embedding"]
        captions_raw = all_data["captions"]
        # self.image_ids = [caption["image_id"] for caption in captions_raw]
        self.captions = [caption['caption'] for caption in captions_raw]

        self.captions_tokens = []
        self.caption2embedding = []
        max_seq_len = 0
        for caption in captions_raw:
            self.captions_tokens.append(
                torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64)
            )
            self.caption2embedding.append(caption["clip_embedding"])
            max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.captions_tokens)

    def pad_tokens(self, item: int):
        tokens = self.captions_tokens[item]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
            self.captions_tokens[item] = tokens
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
            self.captions_tokens[item] = tokens
        mask = tokens.ge(0)  # mask is zero where we out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        return tokens, mask

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
        tokens, mask = self.pad_tokens(item)
        prefix = self.prefixes[self.caption2embedding[item]]
        if self.normalize_prefix:
            prefix = prefix.float()
            prefix = prefix / prefix.norm(2, -1)
        return tokens, mask, prefix

In [5]:
class MLP(nn.Module):
    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

In [6]:
class MultiHeadAttention(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention

In [7]:
class TransformerLayer(nn.Module):
    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

In [8]:
class Transformer(nn.Module):
    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(
                    TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:  # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

In [9]:
class TransformerMapper(nn.Module):
    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

In [10]:
class ClipCaptionModel(nn.Module):
    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained(GPT_CHECKPOINT)
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if mapping_type == MappingType.MLP:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))
        else:
            self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                  clip_length, num_layers)

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

In [11]:
def train(
        train_dataset: ClipCocoDataset,
        eval_dataset: ClipCocoDataset,
        model: ClipCaptionModel,
        batch_size,
        epochs,
        lr: float = 2e-5,
        warmup_steps: int = 5000,
        output_dir: str = ".",
        output_prefix: str = "",
        accum_iter: int = 4
):
    device = torch.device('cuda')
    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=lr)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )
    all_train_losses = {}
    all_eval_losses = {}
    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        model.train()
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        train_losses = []
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
            logits = outputs.logits[:, train_dataset.prefix_length - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
            loss.backward()

            # weights update
            # if ((idx + 1) % accum_iter == 0) or (idx + 1 == len(train_dataloader)):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            train_loss = loss.item()
            progress.set_postfix({"loss": train_loss})
            train_losses.append(train_loss)
            progress.update()
            if (idx + 1) % 10000 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.pt"),
                )
        progress.close()
        # print(f"Total train loss for epoch {epoch} is {accum_train_loss}")
        if epoch % 1 == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
            )

        # evaluation
        model.eval()
        sys.stdout.flush()
        progress = tqdm(total=len(eval_dataloader), desc=output_prefix)
        eval_losses = []
        for idx, (tokens, mask, prefix) in enumerate(eval_dataloader):
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
            logits = outputs.logits[:, eval_dataset.prefix_length - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
            eval_loss = loss.item()
            progress.set_postfix({"loss": eval_loss})
            eval_losses.append(eval_loss)
            progress.update()

        progress.close()
        all_train_losses[f"epoch_{epoch}"] = train_losses.copy()
        all_eval_losses[f"epoch_{epoch}"] = eval_losses.copy()
        # print(f"Total eval loss for epoch {epoch} is {accum_eval_loss}")

    return model, all_train_losses, all_eval_losses

In [12]:
prefix_length = 10
train_dataset = ClipCocoDataset(
    "radclip_transformer_train_new.p",
    prefix_length,
    gpt2_type=GPT_CHECKPOINT
)

Data size is 71866


In [13]:
eval_dataset = ClipCocoDataset(
    "radclip_transformer_test_new.p",
    prefix_length,
    gpt2_type=GPT_CHECKPOINT
)

Data size is 3783


In [14]:
prefix_dim = 512

In [15]:
model = ClipCaptionModel(
    prefix_length, clip_length=10, prefix_size=prefix_dim, num_layers=8, mapping_type=MappingType.Transformer
)

In [16]:
sys.stdout.flush()

In [17]:
model, tloss, eloss = train(train_dataset, eval_dataset, model, 14, 4, output_prefix="new_radclip_gptmed")

>>> Training epoch 0


new_radclip_gptmed:  44%|████▍     | 13451/30697 [1:47:57<2:18:15,  2.08it/s, loss=1.13]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

new_radclip_gptmed:  69%|██████▊   | 21104/30697 [2:49:39<1:17:06,  2.07it/s, loss=1.27] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

new_radclip_gptmed: 100%|██████████| 30697/30697 [4:06:54<00:00,  2.07it/s, loss=1.44] 
new_radclip_gptmed: 100%|██████████| 1621/1621 [02:13<00:00, 12.10it/s, loss=1.47] 

>>> Training epoch 1



new_radclip_gptmed:  18%|█▊        | 5444/30697 [43:34<3:21:58,  2.08it/s, loss=1.32] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

new_radclip_gptmed:  95%|█████████▍| 29109/30697 [3:53:56<12:44,  2.08it/s, loss=1.45]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

new_radclip_gptmed:  15%|█▌        | 4750/30697 [38:02<3:27:28,  2.08it/s, loss=1.14] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to

In [18]:
torch.save(model.state_dict(), "clip_caption_model_transformer_10k_new")

In [21]:
np.mean(eloss["epoch_2"])

1.4610324561559147

In [20]:
np.mean(eloss["epoch_1"])

1.4803535490033068

In [22]:
tloss["epoch_2"]

[1.3629426956176758,
 1.2069710493087769,
 1.8457428216934204,
 1.8556292057037354,
 1.4643269777297974,
 1.043078899383545,
 1.5649605989456177,
 1.338426947593689,
 1.1791372299194336,
 1.363163709640503,
 1.669366478919983,
 1.3977620601654053,
 1.6598552465438843,
 1.3277424573898315,
 1.4018137454986572,
 1.5364246368408203,
 1.4133042097091675,
 1.5706279277801514,
 1.4096860885620117,
 1.4590132236480713,
 1.4180641174316406,
 1.3646734952926636,
 1.461068868637085,
 1.094805121421814,
 1.3705743551254272,
 1.1924760341644287,
 1.5594137907028198,
 1.4929217100143433,
 1.0139890909194946,
 1.6037801504135132,
 1.9003727436065674,
 1.5881577730178833,
 1.5539162158966064,
 1.5637794733047485,
 1.5678178071975708,
 1.6485785245895386,
 1.388034701347351,
 1.754345178604126,
 1.1583211421966553,
 1.677202582359314,
 1.9811991453170776,
 1.3265657424926758,
 1.5462720394134521,
 1.866016149520874,
 1.4394484758377075,
 1.4259629249572754,
 1.3053988218307495,
 1.4140522480010986,
 1

In [23]:
type(model)

__main__.ClipCaptionModel