# Imports

In [None]:
import transformers
from transformers import CLIPConfig, CLIPModel, CLIPProcessor, CLIPImageProcessor, CLIPTokenizerFast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import random
import math
import scipy.io as sio
import nibabel as nib
from pathlib import Path
from gensim.models import Word2Vec
import re

In [None]:
# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Load word and fMRI data

### Load 3d fMRI data

In [None]:
NUM_SUBJS = 8
subjects_fmri = [] #stores one 4d fmri np array per subject, in subject-id order

fMRI_folder = Path('./doi_10.5061_dryad.gt413__v1')
assert fMRI_folder.exists(), f"Folder: {fMRI_folder} does not exist."

#loads every subject's 4d nifti scan fully into memory as a numpy array
for subj_id in range(NUM_SUBJS):
#     fmri_file_name = str(subj_id) + '_masked_2d.npy'
#     fmri = np.load(fMRI_folder / fmri_file_name)
    fmri_file_name = f"{subj_id}_smooth_nifti_4d.nii"
    fmri = np.array(nib.load(fMRI_folder / fmri_file_name).dataobj)
    assert isinstance(fmri, np.ndarray), f"Imported fmri_scan for subject {subj_id} is not of type numpy.ndarray"
    assert fmri.ndim == 4, f"Imported fmri_scan for subject {subj_id} is not 4 dimensional"
    subjects_fmri.append(fmri)

# Load words

In [None]:
feature_matrix = np.zeros((5176,195)) #one row of feature values per word
feature_names = [] #names of all features, in the same order as the matrix columns
feature_types = {} #maps each feature type to the list of feature names belonging to it

features = sio.loadmat(fMRI_folder / 'story_features.mat')
feature_count = 0 #index of the next unused column of feature_matrix
for feature_type in features['features'][0]:
    type_name = feature_type[0][0]
    names_field = feature_type[1][0]
    type_values = feature_type[2]
    #a plain string is treated as a single feature name; otherwise the first item of each entry is the name
    if isinstance(names_field, str):
        type_feature_names = [names_field]
    else:
        type_feature_names = [feature[0] for feature in names_field]
    feature_types[type_name] = type_feature_names
    feature_names.extend(type_feature_names)
    #copies this group's value columns into the next free columns of the feature matrix
    next_count = feature_count + type_values.shape[1]
    feature_matrix[:, feature_count:next_count] = type_values
    feature_count = next_count

In [None]:
#only the first subject file is read; the original author notes the word timings are the same for all subjects
mat_file = fMRI_folder / 'subject_1.mat'
mat_contents = sio.loadmat(mat_file)

#one (word, time, features) tuple per row of the 'words' table, kept in file order
words_info = [
    (row[0][0][0][0], row[1][0][0], feature_matrix[count,:])
    for count, row in enumerate(mat_contents['words'][0])
]

In [None]:
#concatenates every word of the 'words' table, each followed by a single space
chapter_nine_text = "".join(row[0][0][0][0] + " " for row in mat_contents['words'][0])
print(chapter_nine_text)

### Align fMRI scans with sets of 4 words

In [None]:
#Builds the training samples: for each subject, a list of (3d fmri scan, [4 words]) tuples.
#A window of 8 consecutive words is inspected at a time; only its first 4 words are stored with the scan.
subjects_samples = [[] for i in range(NUM_SUBJS)] #stores lists of all the samples for each subject

word_count = 0
while word_count < len(words_info) - 8:
    #gets the 4 input words, and the 4 consecutive words while verifying they were read in sequence
    scan_words = []
    start_time = words_info[word_count][1]
    in_sequence = True #tracks if the words are in sequence or not
    for i in range(8):
        word_info = words_info[word_count + i]
        if word_info[1] != start_time + 0.5*i:
            #if some of the words are not in sequence (not exactly 0.5 apart), skip forward 1 word after the inner loop
            in_sequence = False
        scan_words.append(word_info[0])
    if not in_sequence:
        word_count +=1
        continue
    fmri_time = start_time + 2 #effect of reading words is assumed to start 2 seconds after and end 8 seconds after
    fmri_index = fmri_time//2 #since a scan happens every two seconds, the index is the time divided by 2
    #NOTE(review): this check depends on the numpy dtype loadmat produced for the time value, not on the
    #value itself; presumably whole-second times load as np.int32 and fractional ones as floats - confirm
    if not isinstance(fmri_index, np.int32):
        #if the first word is not aligned with the fmri scan (i.e. its not the first word in a TR)
        word_count += 1
        continue
    for count, subject in enumerate(subjects_fmri):
        #adds tuple of (fmri_scan, four words)
        #NOTE(review): the scan taken is at fmri_index+2, i.e. two scans after the index computed above;
        #there is no bounds check against the length of the 4th (time) axis - confirm this offset is intended
        subjects_samples[count].append((subject[:,:,:,fmri_index+2], scan_words[0:4]))
    print("Created sample:")
    print("\tScan time:", str(start_time))
    print("\tInput words:", str(scan_words[0:4]))
    #if successful, skip forward to the next set of 4 words
    word_count += 4

print("Total number of samples:", str(len(subjects_samples[0])))

# OpenAI CLIP

### Tokenizer

In [None]:
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    """Return the expected location of the BPE merges archive.

    The file ``bpe_simple_vocab_16e6.txt.gz`` is expected in the directory
    that contains the ``doi_10.5061_dryad.gt413__v1`` dataset folder, with
    the folder resolved against the current working directory. The result
    is cached after the first call.
    """
    dataset_dir = os.path.abspath("./doi_10.5061_dryad.gt413__v1")
    containing_dir = os.path.dirname(dataset_dir)
    return os.path.join(containing_dir, "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte value (whitespace, control
    characters, ...) is mapped, in increasing byte order, to the code points
    256, 257, ... so that no byte is represented by a character the BPE code
    cannot handle.

    Returns a dict with 256 entries; the directly-mapped bytes come first in
    insertion order, followed by the remapped ones.
    """
    direct = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in direct}
    shifted = 0
    for byte in range(2**8):
        if byte in table:
            continue
        # bytes without a safe printable form get the next code point above 255
        table[byte] = chr(2**8 + shifted)
        shifted += 1
    return table


def get_pairs(word):
    """Collect every pair of adjacent symbols in a word.

    ``word`` is a non-empty sequence of symbols (a tuple of variable-length
    strings in the BPE code). The result is a set of ``(left, right)``
    tuples; a one-symbol word yields an empty set. An empty word raises
    IndexError.
    """
    head, tail = word[0], word[1:]
    # pairing (head, *tail) with tail lines each symbol up with its successor
    return set(zip((head, *tail), tail))


def basic_clean(text):
    """Run ftfy's text repair, undo HTML escaping twice, and strip outer whitespace."""
    repaired = ftfy.fix_text(text)
    unescaped_once = html.unescape(repaired)
    unescaped_twice = html.unescape(unescaped_once)
    return unescaped_twice.strip()


def whitespace_clean(text):
    """Collapse each run of whitespace into one space and strip both ends."""
    return re.sub(r'\s+', ' ', text).strip()


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer.

    The vocabulary is assembled from the 256 byte symbols returned by
    ``bytes_to_unicode``, the same symbols with an end-of-word marker
    (``</w>``), one entry per merge rule read from the BPE file, and the two
    special tokens ``<|startoftext|>`` and ``<|endoftext|>``.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Load the merge rules from the gzip file at ``bpe_path`` and build the lookup tables."""
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # read inside a context manager so the gzip handle is always closed
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # skip the first line and keep exactly as many merges as the 49152-entry vocabulary has room for
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # lower rank = merge is applied earlier
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # the special tokens are pre-seeded so bpe() returns them unsplit
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to one token.

        Returns the resulting sub-word symbols joined by single spaces; the
        last symbol carries the ``</w>`` end-of-word marker. Results are
        memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # pick the adjacent pair with the lowest merge rank; stop once no known merge remains
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again: copy the remainder unchanged
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text``, split it with ``self.pat`` and return the list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # re-express the token's utf-8 bytes with the printable byte alphabet before merging
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Turn a sequence of token ids back into text; each ``</w>`` marker becomes a space."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

### Tokenize

In [None]:
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging

# Shared tokenizer instance used by tokenize(); built from the default BPE file path.
_tokenizer = SimpleTokenizer()

def tokenize(texts: Union[str, List[str]], context_length: int = 16, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Convert one string or a list of strings into a zero-padded tensor of token ids.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single input string or a list of input strings

    context_length : int
        Number of token slots per row, including the start and end tokens.
        The released CLIP models use 77; the default here is 16.

    truncate: bool
        When True, an over-long encoding is cut to ``context_length`` and its
        last slot is overwritten with the end-of-text token; when False an
        over-long input raises RuntimeError

    Returns
    -------
    A two-dimensional tensor of shape [number of input strings, context_length].
    The dtype is long for torch versions below 1.8.0 and int otherwise.
    """
    batch = [texts] if isinstance(texts, str) else texts

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[start_id, *_tokenizer.encode(text), end_id] for text in batch]

    needs_long = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    result = torch.zeros(len(encoded), context_length, dtype=torch.long if needs_long else torch.int)

    for row, ids in enumerate(encoded):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {batch[row]} is too long for context length {context_length}")
            ids = ids[:context_length]
            ids[-1] = end_id
        # unused trailing slots keep their zero padding
        result[row, :len(ids)] = torch.tensor(ids)

    return result

### Model

In [None]:
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    """3D residual bottleneck block: 1x1x1 conv, 3x3x3 conv, optional average pool, 1x1x1 conv.

    The output has ``planes * expansion`` channels. Every convolution uses
    stride 1; when ``stride > 1`` the spatial reduction is done by an average
    pool placed after the second convolution. The skip path is projected
    (average pool + 1x1x1 conv + batch norm) whenever the stride or the
    channel count differs between input and output.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        widened = planes * self.expansion

        # channel reduction
        self.conv1 = nn.Conv3d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        # spatial convolution
        self.conv2 = nn.Conv3d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        # pooling stands in for a strided convolution
        if stride > 1:
            self.avgpool = nn.AvgPool3d(stride)
        else:
            self.avgpool = nn.Identity()

        # channel expansion
        self.conv3 = nn.Conv3d(planes, widened, 1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)
        self.relu3 = nn.ReLU(inplace=True)

        shortcut = None
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # the skip path is pooled first, then projected with a stride-1 convolution
            shortcut = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool3d(stride)),
                ("0", nn.Conv3d(inplanes, widened, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm3d(widened))
            ]))
        self.downsample = shortcut
        self.stride = stride

    def forward(self, x: torch.Tensor):
        """Return relu(main_branch(x) + skip(x))."""
        main = self.conv1(x)
        main = self.relu1(self.bn1(main))
        main = self.conv2(main)
        main = self.relu2(self.bn2(main))
        main = self.avgpool(main)
        main = self.bn3(self.conv3(main))

        skip = x if self.downsample is None else self.downsample(x)

        main += skip
        return self.relu3(main)


# class AttentionPool2d(nn.Module):
#     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
#         super().__init__()
#         self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
#         self.k_proj = nn.Linear(embed_dim, embed_dim)
#         self.q_proj = nn.Linear(embed_dim, embed_dim)
#         self.v_proj = nn.Linear(embed_dim, embed_dim)
#         self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
#         self.num_heads = num_heads

#     def forward(self, x):
#         x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
#         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
#         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
#         x, _ = F.multi_head_attention_forward(
#             query=x[:1], key=x, value=x,
#             embed_dim_to_check=x.shape[-1],
#             num_heads=self.num_heads,
#             q_proj_weight=self.q_proj.weight,
#             k_proj_weight=self.k_proj.weight,
#             v_proj_weight=self.v_proj.weight,
#             in_proj_weight=None,
#             in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
#             bias_k=None,
#             bias_v=None,
#             add_zero_attn=False,
#             dropout_p=0,
#             out_proj_weight=self.c_proj.weight,
#             out_proj_bias=self.c_proj.bias,
#             use_separate_proj_weight=True,
#             training=self.training,
#             need_weights=False
#         )
#         return x.squeeze(0)


class ModifiedResNet(nn.Module):
    """
    A 3D ResNet encoder for single-channel volumes, adapted from CLIP's ModifiedResNet:
    - There are 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The attention pooling of the original design is not used here; the output of the last
      residual stage is reshaped to rows of ``width * 8 * Bottleneck.expansion`` values and
      projected to ``output_dim`` by a linear layer.

    Parameters
    ----------
    layers : sequence of 4 ints
        Number of Bottleneck blocks in each of the four residual stages.
    output_dim : int
        Size of the returned embedding.
    heads : int
        Unused; kept for interface compatibility with the attention-pool variant.
    input_resolution : int
        Stored on the instance but not used by the network itself.
    width : int
        Base channel width; the final feature size is ``width * 32`` with the default expansion of 4.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # channel count produced by layer4 (2048 for the default width of 64)
        embed_dim = width * 8 * Bottleneck.expansion

        # the 3-layer stem
        self.conv1 = nn.Conv3d(1, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(width // 2, width // 2, kernel_size=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv3d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm3d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool3d(2)
        # sized from the width instead of a fixed 2048 so non-default widths work
        self.linear = nn.Linear(embed_dim, output_dim)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` Bottlenecks; only the first block applies ``stride``."""
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch of volumes (channel dimension of size 1) into ``output_dim``-sized vectors."""
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # NOTE(review): this reshape yields one row per sample only when the spatial size after
        # layer4 is 1x1x1; with larger inputs the rows will no longer correspond to samples.
        x = x.view(-1, self.linear.in_features)
        x = self.linear(x)
        return x


class LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and returns the input's dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention and a 4x-wide MLP, each with a residual connection."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        hidden_width = d_model * 4

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_width)
        feed_forward["gelu"] = QuickGELU()
        feed_forward["c_proj"] = nn.Linear(hidden_width, d_model)
        self.mlp = nn.Sequential(feed_forward)
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attend over ``x``; the stored mask (if any) is first moved to ``x``'s dtype and device."""
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        after_attention = x + self.attention(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))


class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks of the given width, applied in sequence."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        out = self.resblocks(x)
        return out


class CLIP(nn.Module):
    """Contrastive model pairing a 3D ModifiedResNet encoder with a text transformer.

    Both encoders produce ``embed_dim``-sized vectors; ``forward`` returns the
    scaled cosine-similarity matrix between every image and every text of a batch.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        """Build both encoders and initialise their parameters.

        ``embed_dim`` is the shared output size; the ``vision_*`` / ``image_resolution``
        arguments are forwarded to ModifiedResNet and the ``transformer_*`` arguments to
        Transformer. ``context_length`` fixes the size of the positional embedding and of
        the causal attention mask; ``vocab_size`` fixes the token embedding table.
        """
        super().__init__()

        self.context_length = context_length

        #initializes resnet (removed option for vision transformer)
        vision_heads = vision_width * 32 // 64
        self.visual = ModifiedResNet(
            layers=vision_layers,
            output_dim=embed_dim,
            heads=vision_heads,
            input_resolution=image_resolution,
            width=vision_width
        )

        #initializes text transformer
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        #self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # NOTE(review): logit_scale is a plain tensor attribute here (neither nn.Parameter nor a
        # registered buffer), so it is not trained and presumably is not part of state_dict() nor
        # moved by model.to(...) - confirm this is intended.
        self.logit_scale = torch.ones([]) * np.log(1 / 0.07)

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialise the embeddings, the transformer weights and the text projection with
        normal distributions, and zero the last batch-norm weight (bn3) of every residual block."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
#             if self.visual.attnpool is not None:
#                 std = self.visual.attnpool.c_proj.in_features ** -0.5
#                 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
#                 nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
#                 nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
#                 nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # with bn3.weight at zero, the main branch of each Bottleneck initially outputs only the bn3 bias
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        # standard deviations are scaled by the transformer width (and depth for the output projections)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        """Return a (context_length, context_length) additive causal mask:
        -inf strictly above the diagonal, 0 on and below it."""
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        """dtype of the visual stem's first convolution weight; inputs and activations are cast to it."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Cast ``image`` to the model dtype and run it through the visual encoder."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Embed a batch of token-id rows and return one projected feature vector per row.

        The vector is read at the position of the largest token id in each row.
        """
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        """Return ``(logits_per_image, logits_per_text)``: the scaled cosine similarities
        between all image and text features of the batch, and the transpose of that matrix."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Cast selected parameters of ``model`` to fp16 in place.

    Affected tensors: weight and bias of Conv1d / Conv3d / Linear modules,
    the projection weights and biases of MultiheadAttention modules, and any
    ``text_projection`` or ``proj`` attribute found on a module. Other
    parameters keep their dtype.
    """
    attention_tensor_names = (
        "in_proj_weight", "q_proj_weight", "k_proj_weight", "v_proj_weight",
        "in_proj_bias", "bias_k", "bias_v",
    )
    projection_names = ("text_projection", "proj")

    def _halve_if_present(tensor):
        # unused optional tensors are stored as None and are skipped
        if tensor is not None:
            tensor.data = tensor.data.half()

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv3d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            _halve_if_present(l.bias)

        if isinstance(l, nn.MultiheadAttention):
            for tensor_name in attention_tensor_names:
                _halve_if_present(getattr(l, tensor_name))

        for projection_name in projection_names:
            _halve_if_present(getattr(l, projection_name, None))

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    """Rebuild a CLIP model whose hyperparameters are inferred from a ResNet-style state dict.

    Parameters
    ----------
    state_dict : dict
        Checkpoint mapping parameter names to tensors. The bookkeeping keys
        ``input_resolution``, ``context_length`` and ``vocab_size`` are removed
        from it (the dict is modified in place).

    Returns
    -------
    The model in eval mode, with weights converted to fp16 and loaded from ``state_dict``.

    Raises
    ------
    NotImplementedError
        If the checkpoint is a Vision Transformer variant (contains ``visual.proj``);
        the CLIP class in this file only has the ResNet visual encoder.
    """
    if "visual.proj" in state_dict:
        raise NotImplementedError(
            "ViT checkpoints are not supported: this CLIP implementation only provides the ResNet visual encoder"
        )

    counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
    vision_layers = tuple(counts)
    vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
    # the attention-pool positional embedding has output_width**2 + 1 rows
    output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
    assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
    image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    # keyword arguments keep the call in sync with CLIP.__init__, which takes no vision_patch_size
    model = CLIP(
        embed_dim=embed_dim,
        image_resolution=image_resolution,
        vision_layers=vision_layers,
        vision_width=vision_width,
        context_length=context_length,
        vocab_size=vocab_size,
        transformer_width=transformer_width,
        transformer_heads=transformer_heads,
        transformer_layers=transformer_layers,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights(model)
    # NOTE(review): loading is strict; a checkpoint of the original 2D CLIP (attnpool keys,
    # 2D convolution weights) presumably will not match this 3D architecture - confirm.
    model.load_state_dict(state_dict)
    return model.eval()

In [None]:
# Download URLs of the released CLIP checkpoints, keyed by model name.
# The second-to-last path component of each URL is the file's expected SHA256 (checked by _download).
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}

In [None]:
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

def _sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of the file at ``path``, read in chunks and closed afterwards."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _download(url: str, root: str):
    """Download ``url`` into the directory ``root`` unless a verified copy is already there.

    The expected SHA256 checksum is taken from the second-to-last path
    component of ``url``; the file name is the last component.

    Returns
    -------
    The path of the local file.

    Raises
    ------
    RuntimeError
        If the target path exists but is not a regular file, or if the
        downloaded file does not match the expected checksum.
    """
    # `import urllib` alone does not guarantee that the `request` submodule is loaded
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # the progress bar has no total when the server sends no Content-Length header
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(total=total, ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target

In [None]:
#fetches (or reuses the cached copy of) the pretrained checkpoint for the chosen model name
name="RN50"
if name in _MODELS:
    model_path = _download(_MODELS[name], os.path.expanduser("~/.cache/clip"))
else:
    #lists the known names directly; no available_models() helper is defined in this notebook
    raise RuntimeError(f"Model {name} not found; available models = {list(_MODELS.keys())}")

In [None]:
#reads the checkpoint's state dict, first as a TorchScript archive and otherwise as a plain saved state dict
jit=True
with open(model_path, 'rb') as opened_file:
    try:
        # loading JIT archive
        model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
        state_dict = model.state_dict()
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        # the failed JIT attempt may have advanced the stream; rewind so torch.load sees the whole file
        opened_file.seek(0)
        state_dict = torch.load(opened_file, map_location="cpu")

In [None]:
#Derives the model hyperparameters from the shapes and key names of the loaded checkpoint.

#vision
#number of distinct block indices found under each visual.layer{1..4} prefix
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
print(vision_layers)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
#the attention-pool positional embedding has output_width**2 + 1 rows (checked by the assert below)
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
print(output_width)
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32

#transformer
embed_dim = state_dict["text_projection"].shape[1]
print(embed_dim)
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64 #one attention head per 64 channels of width
#number of distinct block indices under transformer.resblocks
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

In [None]:
def unison_shuffled_copies(a, b):
    """Return shuffled copies of ``a`` and ``b`` that share one random permutation.

    Both inputs must have the same length and support indexing with an index
    array; element ``i`` of each output comes from the same original position.
    """
    size = len(a)
    assert size == len(b)
    order = np.random.permutation(size)
    return a[order], b[order]

In [None]:
def train_clip(model, text_samples, image_samples, batch_size=10, num_epochs=100, lr=1e-3):
    """Train a CLIP-style model from scratch with the symmetric contrastive loss.

    Every epoch the paired samples are reshuffled together and consumed in
    full batches of ``batch_size`` (a trailing partial batch is dropped).
    Within a batch, pair ``i`` is the positive for row ``i`` of both logit
    matrices, so the target labels are ``0 .. batch_size - 1``. Per-batch and
    per-epoch loss and accuracy are printed; nothing is returned.

    Args:
        model: module called as ``model(images, text)`` that returns
            ``(logits_per_image, logits_per_text)``; row ``i`` of each matrix
            holds the scores of item ``i`` against every candidate in the batch.
        text_samples: text inputs, aligned index-by-index with ``image_samples``.
        image_samples: image inputs; a singleton axis is inserted at dim 1
            before they are passed to the model.
        batch_size: number of pairs per step.
        num_epochs: number of passes over the data.
        lr: Adam learning rate.

    Raises:
        ValueError: if there are fewer than ``batch_size`` samples, since no
            full batch could be formed.
    """
    num_batches = image_samples.shape[0] // batch_size
    if num_batches == 0:
        raise ValueError(
            f"Need at least batch_size={batch_size} samples to train, got {image_samples.shape[0]}"
        )

    # NOTE(review): Adam applies weight_decay as an L2 term on the gradient; if
    # decoupled weight decay was intended, AdamW would be the optimizer to use.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.2)
    loss_fn = nn.CrossEntropyLoss()
    # Make sure train-time behaviour is active even if the model was put in
    # eval mode earlier in the session.
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0
        image_epoch_correct = 0
        text_epoch_correct = 0
        epoch_total = 0
        text_samples, image_samples = unison_shuffled_copies(text_samples, image_samples)
        for batch in range(num_batches):
            optimizer.zero_grad()

            # slice out the current batch of aligned (text, image) pairs
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            text_batch = text_samples[start_idx:end_idx]
            image_batch = image_samples[start_idx:end_idx]

            logits_per_image, logits_per_text = model(torch.unsqueeze(image_batch, dim=1), text_batch)

            # symmetric loss: the matching pair sits on the diagonal
            labels = torch.arange(batch_size, device=logits_per_image.device)
            loss_text = loss_fn(logits_per_text, labels)
            loss_image = loss_fn(logits_per_image, labels)
            loss = (loss_text + loss_image) / 2
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            epoch_loss += batch_loss

            # Accuracy must pick the best candidate along dim=1, the same axis
            # CrossEntropyLoss treats as the class scores; dim=0 would score
            # the transposed retrieval direction instead.
            image_winners = torch.argmax(logits_per_image, dim=1)
            text_winners = torch.argmax(logits_per_text, dim=1)
            total_image_correct = (image_winners == labels).sum().item()
            total_text_correct = (text_winners == labels).sum().item()
            image_epoch_correct += total_image_correct
            text_epoch_correct += total_text_correct
            epoch_total += batch_size
            print("\tBatch:", batch, "/", num_batches, ", Loss:", batch_loss, ", Image Accuracy:", total_image_correct/batch_size , ", Text Accuracy:", total_text_correct/batch_size)
        # epoch_loss is the sum (not the mean) of the per-batch losses
        print("Epoch:", epoch, "Training Loss:", epoch_loss, "Training Image Accuracy:", image_epoch_correct/epoch_total, "Training Text Accuracy:", text_epoch_correct/epoch_total)

### Train using real fMRI and text data

In [None]:
# Tokenizer sanity check on two short phrases: show the token ids, then the
# shape of the returned batch.
demo_phrases = ["Harry potter was a", "wizard and he always"]
print(tokenize(demo_phrases))
print(tokenize(demo_phrases).shape)

In [None]:
def split_samples(samples, min_val=None, max_val=None):
    """Turn (image, words) samples into min-max normalised tensors on ``device``.

    Args:
        samples: non-empty sequence of samples where ``sample[0]`` is an
            array-like image (every image must have the shape of the first
            one) and ``sample[1]`` is a sequence of word strings. The words
            are joined with spaces and passed to ``tokenize``.
        min_val: lower bound used for normalisation. When ``None`` it is
            taken from the images in ``samples``.
        max_val: upper bound used for normalisation. When ``None`` it is
            taken from the images in ``samples``.

    Returns:
        ``(images, text, min_val, max_val)``: the normalised image tensor and
        the ``(len(samples), 16)`` integer token tensor, both moved to
        ``device``, plus the bounds that were applied so that another split
        can be normalised identically.
    """
    images = torch.zeros([len(samples)] + list(samples[0][0].shape))
    # NOTE(review): 16 has to match the length of the rows tokenize returns.
    text = torch.zeros((len(samples), 16), dtype=int)
    for idx, sample in enumerate(samples):
        images[idx] = torch.tensor(sample[0])
        text[idx] = tokenize([" ".join(sample[1])])
    # Default each bound on its own: previously max_val was only filled in
    # when min_val was missing, so passing just min_val failed on the
    # subtraction below and passing just max_val was silently overwritten.
    if min_val is None:
        min_val = images.min()
    if max_val is None:
        max_val = images.max()
    # Produces non-finite values if max_val == min_val (constant images).
    images = (images - min_val)/(max_val - min_val)
    return images.to(device), text.to(device), min_val, max_val

In [None]:
# Fraction of the first subject's samples used for training; the rest is test.
train_split = 0.8

# Shuffle in place so the split below is random.
random.shuffle(subjects_samples[0])
train_cutoff = int(len(subjects_samples[0]) * train_split)

# Training portion: normalisation bounds are derived from these images.
train_samples = subjects_samples[0][:train_cutoff]
train_images, train_text, min_val, max_val = split_samples(train_samples)
print("Train:")
print(train_text.shape)
print(train_images.shape)
# A NaN or infinite sum means the tensor cannot be trusted as model input.
if not torch.isfinite(torch.sum(train_text)):
    print('invalid input detected in text.')
if not torch.isfinite(torch.sum(train_images)):
    print('invalid input detected in images.')

# Test portion: reuse the training bounds so both splits share one scale.
test_samples = subjects_samples[0][train_cutoff:]
test_images, test_text, _, _ = split_samples(test_samples, min_val, max_val)
print("Test:")
print(len(test_text))
print(test_images.shape)

In [None]:
# Build a new CLIP instance (no checkpoint weights are loaded in this cell).
# The text-side sizes come from the values read off the pretrained checkpoint;
# the embedding size, image-encoder depth/width and context length are set by
# hand here.
clip_model = CLIP(
    256,                 # embed dim
    image_resolution,
    (2, 2, 2, 2),        # image encoder layers
    64,                  # image encoder width
    16,                  # token context length
    vocab_size,
    transformer_width,
    transformer_heads,
    transformer_layers,
)
# Move the model to the selected device.
clip_model = clip_model.to(device)

### Train model

In [None]:
# Train on the training split; every argument is spelled out by keyword.
train_clip(
    model=clip_model,
    text_samples=train_text,
    image_samples=train_images,
    batch_size=64,
    num_epochs=50,
    lr=1e-4,
)

In [None]:
# for alpha in (1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8, 1e-9):
#     print("Alpha:", alpha)
#     clip_model = CLIP(
#         embed_dim,
#         image_resolution, (2,2,2,2), 64,
#         16, vocab_size, transformer_width, transformer_heads, transformer_layers
#     ).to(device)
#     train_clip(clip_model, train_text, train_images, batch_size=64, lr=alpha, num_epochs=10)

In [None]:
# Persist the model's state dict (not the module object itself) to disk.
torch.save(obj=clip_model.state_dict(), f="./saved_clip_model")

In [None]:
# clip_model = CLIP(
#     embed_dim,
#     image_resolution, (2,2,2,2), 64,
#     16, vocab_size, transformer_width, transformer_heads, transformer_layers
# ).to(device)
# clip_model.load_state_dict(torch.load("./saved_clip_model"))
# clip_model.eval()

### Assess training accuracy

In [None]:
# Measure loss and retrieval accuracy of clip_model on the training split.
# The split is shuffled and scored in full batches of batch_size; within a
# batch, pair i is the correct match for row i of both logit matrices.
total_samples = 0
total_image_correct = 0
total_text_correct = 0
total_loss = 0

batch_size = 64

loss_fn = nn.CrossEntropyLoss()

text_samples, image_samples = unison_shuffled_copies(train_text, train_images)
num_batches = image_samples.shape[0] // batch_size

# Score in eval mode and without autograd so that measuring the model neither
# builds gradient graphs nor triggers train-only layer behaviour; the previous
# mode is restored afterwards so later training cells are unaffected.
was_training = clip_model.training
clip_model.eval()
try:
    with torch.no_grad():
        for batch in range(num_batches):

            #gets embeddings for text and image batches
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            text_batch, image_batch = text_samples[start_idx:end_idx], image_samples[start_idx:end_idx]

            logits_per_image, logits_per_text = clip_model(torch.unsqueeze(image_batch, dim=1), text_batch)

            #symmetric loss function
            labels = torch.arange(batch_size).to(device)
            loss_text = loss_fn(logits_per_text, labels)
            loss_image = loss_fn(logits_per_image, labels)
            loss = (loss_text + loss_image) / 2
            batch_loss = loss.item()
            total_loss += batch_loss

            # Pick the best candidate along dim=1, the axis CrossEntropyLoss
            # reads as class scores; dim=0 scored the transposed direction.
            image_winners = torch.argmax(logits_per_image, dim=1)
            text_winners = torch.argmax(logits_per_text, dim=1)
            image_correct_num = (image_winners == labels).sum().item()
            text_correct_num = (text_winners == labels).sum().item()
            total_image_correct += image_correct_num
            total_text_correct += text_correct_num
            total_samples += batch_size

            print("\tBatch:", batch, "/", num_batches, ", Loss:", batch_loss, ", Image Accuracy:", image_correct_num/batch_size , ", Text Accuracy:", text_correct_num/batch_size)
finally:
    clip_model.train(was_training)

print("Total Training Loss:", total_loss)
if total_samples:
    print("Image Training Accuracy: ", total_image_correct/total_samples, "Text Training Accuracy: ", total_text_correct/total_samples)
else:
    # Fewer samples than one batch: there is nothing to average over.
    print("No full batch of", batch_size, "samples available; accuracy not computed.")

### Assess testing accuracy

In [None]:
# Measure loss and retrieval accuracy of clip_model on the held-out test split.
# The split is shuffled and scored in full batches of batch_size; within a
# batch, pair i is the correct match for row i of both logit matrices.
test_loss = 0
image_test_correct = 0
text_test_correct = 0
test_total = 0

batch_size = 64
loss_fn = nn.CrossEntropyLoss()

text_samples, image_samples = unison_shuffled_copies(test_text, test_images)
num_batches = image_samples.shape[0] // batch_size

# Score in eval mode and without autograd so that test data neither builds
# gradient graphs nor triggers train-only layer behaviour; the previous mode
# is restored afterwards so later training cells are unaffected.
was_training = clip_model.training
clip_model.eval()
try:
    with torch.no_grad():
        for batch in range(num_batches):

            #gets embeddings for text and image batches
            start_idx = batch * batch_size
            end_idx = start_idx + batch_size
            text_batch, image_batch = text_samples[start_idx:end_idx], image_samples[start_idx:end_idx]

            logits_per_image, logits_per_text = clip_model(torch.unsqueeze(image_batch, dim=1), text_batch)

            #symmetric loss function
            labels = torch.arange(batch_size).to(device)
            loss_text = loss_fn(logits_per_text, labels)
            loss_image = loss_fn(logits_per_image, labels)
            loss = (loss_text + loss_image) / 2
            batch_loss = loss.item()
            test_loss += batch_loss

            # Pick the best candidate along dim=1, the axis CrossEntropyLoss
            # reads as class scores; dim=0 scored the transposed direction.
            image_winners = torch.argmax(logits_per_image, dim=1)
            text_winners = torch.argmax(logits_per_text, dim=1)

            # Per-batch counts get their own names so the totals computed by
            # the training-accuracy cell are not overwritten.
            image_correct_num = (image_winners == labels).sum().item()
            text_correct_num = (text_winners == labels).sum().item()
            image_test_correct += image_correct_num
            text_test_correct += text_correct_num
            test_total += batch_size

            print("\tBatch:", batch, "/", num_batches, ", Loss:", batch_loss, ", Image Accuracy:", image_correct_num/batch_size , ", Text Accuracy:", text_correct_num/batch_size)
finally:
    clip_model.train(was_training)

# These are test metrics; the summary lines used to be labelled "Training".
print("Total Testing Loss:", test_loss)
if test_total:
    print("Image Testing Accuracy: ", image_test_correct/test_total, "Text Testing Accuracy: ", text_test_correct/test_total)
else:
    # Fewer samples than one batch: there is nothing to average over.
    print("No full batch of", batch_size, "samples available; accuracy not computed.")