In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union
import pandas as pd

In [21]:
train_cap=pd.read_pickle("all_data/total_data.pkl")

In [25]:
train_cap

Unnamed: 0,caption,id,name,width,height,caption_y
0,Computed tomography scan in axial view showin...,ROCO_00002,PMC4083729_AMHSR-4-14-g002.jpg,704,704,1
1,Bacterial contamination occurred after comple...,ROCO_00003,PMC2837471_IJD2009-150251.001.jpg,763,745,1
2,The patient had residual paralysis of the han...,ROCO_00004,PMC2505281_11999_2007_30_Fig6_HTML.jpg,1258,737,1
3,Panoramic radiograph after immediate loading.\n,ROCO_00005,PMC3745845_IJD2013-683423.005.jpg,325,600,1
4,Plain abdomen x-ray: Multiple air levels at t...,ROCO_00007,PMC4917066_amjcaserep-17-301-g001.jpg,499,600,1
...,...,...,...,...,...,...
65445,Initial CT abdomen with contrast showing a di...,ROCO_81819,PMC3517833_CRIM.HEMATOLOGY2012-490438.001.jpg,484,600,1
65446,44-year-old male patient after surgical amput...,ROCO_81820,PMC5487234_rb-50-03-0190-g13.jpg,580,771,1
65447,Primary pulmonary tuberculosis in 18-year-old...,ROCO_81821,PMC2974222_kjr-11-612-g001.jpg,484,496,1
65448,"MRI brain with gadolinium, coronal view, show...",ROCO_81822,PMC3532764_AJNS-7-151-g002.jpg,591,518,1


In [24]:
train_cap[train_cap.name=='PMC2892771_yjbm_83_2_67_g01.jpg']

Unnamed: 0,caption,id,name,width,height,caption_y
6188,CT scan without contrast. Low-density lesion ...,ROCO_07731,PMC2892771_yjbm_83_2_67_g01.jpg,,,1


In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup #AdamW to be deprecated in future
from tqdm import tqdm
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union
import pandas as pd

class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'


class ClipROCODataset(Dataset):

    def __len__(self) -> int:
        return len(self.captions_tokens)

    def pad_tokens(self, item: int):
        tokens = self.captions_tokens[item]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
            self.captions_tokens[item] = tokens
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
            self.captions_tokens[item] = tokens
        mask = tokens.ge(0)  # mask is zero where we out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        return tokens, mask

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
        tokens, mask = self.pad_tokens(item)
        print(item)
        print(self.caption2embedding[item])
        print(self.prefixes)
        prefix = self.prefixes[self.caption2embedding[item]]
        if self.normalize_prefix:
            prefix = prefix.float()
            prefix = prefix / prefix.norm(2, -1)
        return tokens, mask, prefix

    def __init__(self, data_path: str,  prefix_length: int, gpt2_type: str = "gpt2",
                 normalize_prefix=False):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        with open(data_path, 'rb') as f:
            all_data = pickle.load(f)
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        #print(all_data.keys())
        sys.stdout.flush()
        self.prefixes = all_data["clip_embedding"]
        print(len(self.prefixes))
        captions_raw = all_data["captions"]
#         print(captions_raw[:5])
        self.image_ids = [caption["id"] for caption in captions_raw]
        self.captions = [caption['caption'] for caption in captions_raw]
        if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"):
            with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f:
                self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
        else:
            self.captions_tokens = []
            self.caption2embedding = []
            max_seq_len = 0
            for caption in captions_raw:
                self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64))
                self.caption2embedding.append(caption["clip_embedding"])
                max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
            # self.max_seq_len = max_seq_len
            print(max_seq_len)
            with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f:
                pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)
        all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float()
        print(len(self.caption2embedding))
        self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))


class MLP(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)


class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)


class Transformer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec: # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)


class TransformerMapper(nn.Module):

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)


class ClipCaptionModel(nn.Module):

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if mapping_type == MappingType.MLP:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))
        else:
            self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                                     clip_length, num_layers)


class ClipCaptionPrefix(ClipCaptionModel):

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self


def save_config(args: argparse.Namespace):
    config = {}
    for key, item in args._get_kwargs():
        config[key] = item
    out_path = os.path.join(args.out_dir, f"{args.prefix}.json")
    with open(out_path, 'w') as outfile:
        json.dump(config, outfile)


def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
    with open(config_path) as f:
        config = json.load(f)
    parser = argparse.ArgumentParser()
    parser.set_defaults(**config)
    args = parser.parse_args()
    if type(epoch_or_latest) is int:
        epoch_or_latest = f"-{epoch_or_latest:03d}"
    model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
    if args.only_prefix:
        model = ClipCaptionPrefix(args.prefix_length)
    else:
        model = ClipCaptionModel(args.prefix_length)
    if os.path.isfile(model_path):
        print(f"loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    else:
        print(f"{model_path} is not exist")
    return model, parser


def train(dataset: ClipROCODataset, model: ClipCaptionModel, args,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = ""):

    device = torch.device('cpu')
    batch_size = args.bs
    epochs = args.epochs
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )
    # save_config(args)
    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
            logits = outputs.logits[:, dataset.prefix_length - 1: -1]
            loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            progress.set_postfix({"loss": loss.item()})
            progress.update()
            if (idx + 1) % 10000 == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}_latest.pt"),
                )
        progress.close()
        if epoch % args.save_every == 0 or epoch == epochs - 1:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
            )
    return model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', default='clip/roco/ViT-B_32_train.pkl')
    parser.add_argument('--out_dir', default='./checkpoints')
    parser.add_argument('--prefix', default='roco_prefix', help='prefix for saved filenames')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_every', type=int, default=1)
    parser.add_argument('--prefix_length', type=int, default=10)
    parser.add_argument('--prefix_length_clip', type=int, default=10)
    parser.add_argument('--bs', type=int, default=1)
    parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
    parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
    parser.add_argument('--num_layers', type=int, default=8)
    parser.add_argument('--is_rn', dest='is_rn', action='store_true')
    parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
    args = parser.parse_args()
    prefix_length = args.prefix_length
    dataset = ClipROCODataset(args.data, prefix_length, normalize_prefix=args.normalize_prefix)
    prefix_dim = 640 if args.is_rn else 512
    args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type]
    if args.only_prefix:
        model = ClipCaptionPrefix(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
                                  num_layers=args.num_layers, mapping_type=args.mapping_type)
        print("Train only prefix")
    else:
        model = ClipCaptionModel(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
                                  num_layers=args.num_layers, mapping_type=args.mapping_type)
        print("Train both prefix and GPT")
        sys.stdout.flush()
    train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix)



In [47]:
prefix_dim=640

In [124]:
dataset = ClipROCODataset('/Users/simranmasand/Downloads/cs263_final/clip/roco/ViT-B_32_train_final.pkl',prefix_length=10)

Data size is 65419
65419
658
65419


In [76]:
with open('/Users/simranmasand/Downloads/cs263_final/clip/roco/ViT-B_32_train.pkl', 'rb') as f:
    all_data = pickle.load(f)
print("Data size is %0d" % len(all_data["clip_embedding"]))
#print(all_data.keys())
sys.stdout.flush()
prefixes = all_data["clip_embedding"]


Data size is 65419


In [89]:
prefixes


tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])

In [77]:
prefixes.shape

torch.Size([65419, 512])

In [78]:
captions_raw = all_data["captions"]

In [88]:
captions_raw[0]

id                                                       ROCO_00002
name                                 PMC4083729_AMHSR-4-14-g002.jpg
caption            Computed tomography scan in axial view showin...
clip_embedding                                                    0
Name: 0, dtype: object

In [91]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [93]:
captions_tokens = []
caption2embedding = []
max_seq_len = 0
for caption in captions_raw:
    captions_tokens.append(torch.tensor(tokenizer.encode(caption['caption']), dtype=torch.int64))
    caption2embedding.append(caption["clip_embedding"])
    max_seq_len = max(max_seq_len, captions_tokens[-1].shape[0])
    print(max_seq_len)
            

20
25
67
67
67
70
70
70
70
70
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
80
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
89
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
94
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159
159


228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228
228


328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328


328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328


328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328


328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328


328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328


328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328
328


567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567


567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567


567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
567
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658
658


In [9]:
%pwd

'/Users/simranmasand/Downloads'

In [8]:
%cd ..

/Users/simranmasand/Downloads


In [11]:
data = pd.read_csv("all_data/test/radiologytestdata.csv") #is this the best we got
#     with open('.all_data/train/radiologytrain_caption.json', 'r') as f:
#         data = json.load(f)

In [15]:
to_delete = []
for i in tqdm(range(len(data))):
    d = data.iloc[i]
    img_id = d["name"]
    filename = f"all_data/test/images/{img_id}"
    if not os.path.isfile(filename):
        to_delete.append(img_id)

100%|██████████| 8179/8179 [00:00<00:00, 33064.25it/s]


In [16]:
to_delete

['PMC5241049_nihms839240f4.jpg',
 'PMC5357066_emss-71420-f001.jpg',
 'PMC4544285_anec0019-0193-f3.jpg']

In [17]:
len(to_delete)

3

In [120]:
data[data["name"]=='PMC2892771_yjbm_83_2_67_g01.jpg']

Unnamed: 0,id,name,caption
6188,ROCO_07731,PMC2892771_yjbm_83_2_67_g01.jpg,CT scan without contrast. Low-density lesion ...


In [19]:
updated_data = data[~data["name"].str.contains(pattern)]

In [18]:
pattern = '|'.join(to_delete)

In [20]:
updated_data.to_csv('all_data/test/radiologytestdata_final.csv')

In [21]:
len(updated_data)

8176

In [38]:
args = {"out_dir":"/CS263_final/checkpoints","epochs":3,"save_every":2,"prefix_length":10,"prefix_length_clip":10,"bs":1,
        "num_layers":10,"mapping_type" :{'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}["mlp"],
       "output_prefix":"roco_prefix"}

In [35]:
args["prefix_length"]

10

In [50]:
model = ClipCaptionModel(args["prefix_length"], clip_length=args["prefix_length_clip"], prefix_size=prefix_dim,
num_layers=args["num_layers"], mapping_type=args["mapping_type"])

In [72]:
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True, drop_last=True)

In [73]:
i=0
while i <5:
    for idx, (tokens, mask, prefix) in enumerate(train_dataloader):        
        print(idx, (tokens.shape, mask.shape, prefix.shape))
        i+=1
    

18879
18884
0 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
55723
55749
1 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
20137
20144
2 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
37415
37432
3 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
45826
45847
4 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
24333
24346
5 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
23199
23211
6 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
51185
51210
7 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
2097
2097
8 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
8569
8571
9 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
46047
46068
10 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
54895
54921
11 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
6614
6616
12 (torch.Size([1, 2

7187
7189
105 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
45816
45837
106 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
26528
26541
107 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
2860
2860
108 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
54668
54694
109 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
52158
52183
110 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
12977
12980
111 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
58069
58096
112 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
5975
5975
113 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
7438
7440
114 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
50841
50866
115 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
19129
19134
116 (torch.Size([1, 285]), torch.Size([1, 295]), torch.Size([1, 512]))
15060
15064


IndexError: index 65429 is out of bounds for dimension 0 with size 65419

In [39]:
# parser = argparse.ArgumentParser()
# # parser.add_argument('--data', default='./data/coco/oscar_split_train.pkl')
# parser.add_argument('--out_dir', default="/CS263_final/checkpoints")
# parser.add_argument('--prefix', default='roco_prefix', help='prefix for saved filenames')
# parser.add_argument('--epochs', type=int, default=10)
# parser.add_argument('--save_every', type=int, default=1)
# parser.add_argument('--prefix_length', type=int, default=10)
# parser.add_argument('--prefix_length_clip', type=int, default=10)
# parser.add_argument('--bs', type=int, default=40)
# parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
# parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
# parser.add_argument('--num_layers', type=int, default=8)
# parser.add_argument('--is_rn', dest='is_rn', action='store_true')
# parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
# args = parser.parse_args()
# prefix_length = args.prefix_length
# prefix_dim = 640 if args.is_rn else 512
# args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type]

In [138]:
prefix_dim = 512 #if args.is_rn else 512
model = ClipCaptionPrefix(10, clip_length=args["prefix_length_clip"], prefix_size=prefix_dim,num_layers=args["num_layers"], mapping_type=args["mapping_type"])
#train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix)



In [None]:
train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix)

In [129]:
args = {"out_dir":"./checkpoints","epochs":3,"save_every":2,"prefix_length":10,"prefix_length_clip":10,"bs":1,
        "num_layers":10,"mapping_type" :{'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}["mlp"],
       "output_prefix":"roco_prefix"}

In [131]:
device = torch.device('cpu')
batch_size = args["bs"]
epochs = 3
if not os.path.exists(args["out_dir"]):
    os.makedirs(args["out_dir"])
model = model.to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=2e-5)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=5000, num_training_steps=epochs * len(train_dataloader)
)
# 



In [None]:
prefix_dim = 512 #if args.is_rn else 512

In [139]:
output_prefix= args["output_prefix"]
for epoch in range(epochs):
    print(f">>> Training epoch {epoch}")
    sys.stdout.flush()
    progress = tqdm(total=len(train_dataloader), desc=output_prefix)
    for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
        model.zero_grad()
        tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
        outputs = model(tokens, prefix, mask)
        logits = outputs.logits[:, dataset.prefix_length - 1: -1]
        loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress.set_postfix({"loss": loss.item()})
        progress.update()
        if (idx + 1) % 10 == 0:
            torch.save(
                model.state_dict(),
                os.path.join(args["out_dir"], f"{output_prefix}_latest.pt"),
            )
    progress.close()
    if epoch % args.save_every == 0 or epoch == epochs - 1:
        torch.save(
            model.state_dict(),
            os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
        )

>>> Training epoch 0



roco_prefix:   0%|          | 0/65419 [02:20<?, ?it/s][A

5026
5026
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])




roco_prefix:   0%|          | 0/65419 [00:02<?, ?it/s, loss=7.91][A
roco_prefix:   0%|          | 1/65419 [00:02<43:26:54,  2.39s/it, loss=7.91][A

24462
24475
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 1/65419 [00:04<43:26:54,  2.39s/it, loss=4.83][A
roco_prefix:   0%|          | 2/65419 [00:04<41:56:07,  2.31s/it, loss=4.83][A

18895
18900
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 2/65419 [00:06<41:56:07,  2.31s/it, loss=7.32][A
roco_prefix:   0%|          | 3/65419 [00:06<39:20:00,  2.16s/it, loss=7.32][A

45109
45130
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 3/65419 [00:08<39:20:00,  2.16s/it, loss=7.82][A
roco_prefix:   0%|          | 4/65419 [00:08<38:50:59,  2.14s/it, loss=7.82][A

34054
34071
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 4/65419 [00:10<38:50:59,  2.14s/it, loss=6.24][A
roco_prefix:   0%|          | 5/65419 [00:10<36:37:05,  2.02s/it, loss=6.24][A

52805
52830
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 5/65419 [00:12<36:37:05,  2.02s/it, loss=5.69][A
roco_prefix:   0%|          | 6/65419 [00:12<38:34:33,  2.12s/it, loss=5.69][A

38250
38267
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 6/65419 [00:15<38:34:33,  2.12s/it, loss=7.71][A
roco_prefix:   0%|          | 7/65419 [00:15<38:48:29,  2.14s/it, loss=7.71][A

44014
44034
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 7/65419 [00:17<38:48:29,  2.14s/it, loss=4.74][A
roco_prefix:   0%|          | 8/65419 [00:17<39:30:15,  2.17s/it, loss=4.74][A

11786
11789
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 8/65419 [00:19<39:30:15,  2.17s/it, loss=5.17][A
roco_prefix:   0%|          | 9/65419 [00:19<38:22:17,  2.11s/it, loss=5.17][A

53483
53508
tensor([[ 0.4109, -0.1643,  0.0036,  ...,  0.4382, -0.2597, -0.0791],
        [ 0.4392,  0.2154,  0.2748,  ...,  0.6157, -0.1521,  0.5884],
        [ 0.2488,  0.1515,  0.1376,  ...,  0.4411,  0.1564,  0.3754],
        ...,
        [ 0.3539, -0.3620, -0.3011,  ...,  0.1995, -0.1858,  0.2059],
        [ 0.2768, -0.2366, -0.2561,  ...,  0.4807, -0.3081,  0.0027],
        [ 0.2766, -0.0282, -0.0473,  ...,  0.5679, -0.0079,  0.0456]])



roco_prefix:   0%|          | 9/65419 [00:20<38:22:17,  2.11s/it, loss=7.63][A
roco_prefix:   0%|          | 10/65419 [00:20<32:45:14,  1.80s/it, loss=7.63][A

NameError: name 'output_dir' is not defined

In [4]:
#@title Imports

import clip
import os
from torch import nn
import numpy as np
import torch
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
# from google.colab import files
import skimage.io as io
import PIL.Image
from IPython.display import Image 


N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]


D = torch.device
CPU = torch.device('cpu')


def get_device(device_id: int) -> D:
    if not torch.cuda.is_available():
        return CPU
    device_id = min(torch.cuda.device_count() - 1, device_id)
    return torch.device(f'cuda:{device_id}')


CUDA = get_device

current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
os.makedirs(save_path, exist_ok=True)
model_path = os.path.join(save_path, 'model_wieghts.pt')


In [5]:
current_directory

'/Users/simranmasand/Downloads'

In [6]:
#@title Model

class MLP(nn.Module):

    def forward(self, x: T) -> T:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) -1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)


class ClipCaptionModel(nn.Module):

    #@functools.lru_cache #FIXME
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        #print(embedding_text.size()) #torch.Size([5, 67, 768])
        #print(prefix_projections.size()) #torch.Size([5, 1, 768])
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))


class ClipCaptionPrefix(ClipCaptionModel):

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self

In [7]:
#@title Caption prediction

def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
                  entry_length=67, temperature=1., stop_token: str = '.'):

    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)
        for i in range(entry_length):
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()
            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    return output_texts


def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        temperature=1.,
        stop_token: str = '.',
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in trange(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                                                    ..., :-1
                                                    ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]

In [8]:
#@title GPU/CPU


is_gpu = True #@param {type:"boolean"}  


In [9]:
#@title CLIP model + GPT2 tokenizer

device = CUDA(0) if is_gpu else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

100%|███████████████████████████████████████| 338M/338M [00:11<00:00, 31.4MiB/s]


In [10]:
#@title Load model weights


prefix_length = 10

model = ClipCaptionModel(prefix_length)

model.load_state_dict(torch.load(model_path, map_location=CPU)) 

model = model.eval() 
device = CUDA(0) if is_gpu else "cpu"
model = model.to(device)


Downloading pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

FileNotFoundError: [Errno 2] No such file or directory: '/Users/simranmasand/pretrained_models/model_wieghts.pt'

In [141]:
import platform
platform.platform()
# [GOOD] >> macOS-12.4-arm64-arm-64bit
# [BAD]  >> macOS-11.8-x86_64-i386-64bit

'macOS-13.3.1-arm64-arm-64bit'

In [142]:
torch.has_mps

True