In [None]:
%load_ext autoreload
%autoreload 2

In [2]:
from fastai.vision.all import *
from self_supervised.layers import *
import sklearn

In [3]:
datapath = Path("../data/stanford-dogs-dataset/")

In [4]:
train_df = pd.read_csv(datapath/'train.csv')
test_df = pd.read_csv(datapath/'test.csv')
sample_df = pd.read_csv(datapath/'sample_train.csv')

In [5]:
train_df.shape, test_df.shape, sample_df.shape

((12000, 2), (8580, 2), (6000, 3))

In [6]:
# train_df.head()
# test_df.head()
# sample_df.head()

### Dataset

In [7]:
def read_image(filename):
    "Open `filename` from the dataset's `images/Images` folder as a `PILImage`"
    image_path = datapath/'images/Images'/filename
    return PILImage.create(image_path)
def read_image_size(filename):
    "Open `filename` from the dataset's `images/Images` folder and return the image's `.shape`"
    image = PILImage.create(datapath/'images/Images'/filename)
    return image.shape

In [8]:
FAST = True

In [9]:
# Choose which table provides the (filename, label) pairs: the small sample when FAST, else the full train set.
if FAST:
    source_df, fname_col, label_col = sample_df, 'filename', 'label'
else:
    # NOTE(review): these column names are plural while the sample table uses the singular
    # 'filename'/'label' -- confirm against train.csv, this branch may raise KeyError.
    source_df, fname_col, label_col = train_df, 'filenames', 'labels'

filenames = source_df[fname_col].values
labels = source_df[label_col].values
# filename -> label lookup used by `read_label`
fn2label = dict(zip(filenames, labels))

In [10]:
def read_label(filename):
    "Look up the label recorded for `filename` in the `fn2label` mapping"
    label = fn2label[filename]
    return label

In [11]:
valid_filenames = sample_df.query("split == 'valid'")['filename'].values

In [12]:
# Image size and batch size used for the warm-up model
size,bs = 384,32

# Item pipelines: x = open image -> tensor -> random resized crop, y = filename -> label -> category
tfms = [[read_image, ToTensor, RandomResizedCrop(size, min_scale=.75)],
        [read_label, Categorize()]]

# Build the membership lookup once as a set: `o in <numpy array>` rescans the whole
# array for every filename, which makes the split quadratic.
valid_filename_set = set(valid_filenames)
valid_splitter = lambda o: o in valid_filename_set
dsets = Datasets(filenames, tfms=tfms, splits=FuncSplitter(valid_splitter)(filenames))

batch_augs = aug_transforms()
# batch_augs = []

stats = imagenet_stats

# Batch pipeline: int -> float, augmentations, then normalisation with `stats`
batch_tfms = [IntToFloatTensor] + batch_augs + [Normalize.from_stats(*stats)]
dls = dsets.dataloaders(bs=bs, after_batch=batch_tfms)

In [13]:
len(dls.train_ds), len(dls.valid_ds)

(4800, 1200)

In [14]:
# dls.show_batch()

### Modifications on ViT

In [15]:
from utils.custom_vit import *

In [16]:
# timm vit _encoder
arch = "vit_base_patch16_384"
_encoder = create_encoder(arch, pretrained=True, n_in=3)

In [17]:
# custom vit encoder with timm weights
encoder = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
encoder.head = Identity()
encoder.load_state_dict(_encoder.state_dict());

In [18]:
# grad checkpointing
encoder = CheckpointVisionTransformer(encoder, 12)

In [19]:
# 1) Change Stride Size

patch_size,stride_size = 16,16

# new_patch_embed = PatchEmbed(size, patch_size, stride_size)
# new_patch_embed.proj.weight.data = encoder.vit_model.patch_embed.proj.weight.data
# new_patch_embed.proj.bias.data = encoder.vit_model.patch_embed.proj.bias.data
# encoder.vit_model.patch_embed = new_patch_embed

# 2) Interpolate Position Embeddings to new Number of Patches

# num_patches = ((size - patch_size + stride_size) // stride_size)**2 + 1

# pos_embed_data = encoder.vit_model.pos_embed.data
# new_pos_embed_data = F.interpolate(pos_embed_data[None, ...], 
#                                    size=[num_patches, pos_embed_data.size(-1)], 
#                                    mode='nearest')[0]
# encoder.vit_model.pos_embed.data = new_pos_embed_data

3) Create Model

In [20]:
# Probe the encoder with a random 2-image batch (no gradients needed) to read off its output width
with torch.no_grad():
    out, attn_wgts = encoder(torch.randn(2, 3, size, size))
nf = out.size(1)

# Classification head on top of the encoder output: one hidden layer of 768 units, no batchnorm, no dropout
classifier = create_cls_module(nf, dls.c, lin_ftrs=[768], use_bn=False, first_bn=False, ps=0.)



In [21]:
class FGVCModel(Module):
    "Chain an `encoder` and a `classifier` head; optionally also return attention weights and the cls token"
    def __init__(self, encoder, classifier, return_attn_wgts=False):
        self.encoder = encoder
        self.classifier = classifier
        # when True, `forward` returns (logits, attn_wgts, cls_token) instead of just logits
        self.return_attn_wgts = return_attn_wgts

    def forward(self, x):
        # the encoder returns a pair: the cls-token features and the attention weights
        cls_token, attn_wgts = self.encoder(x)
        logits = self.classifier(cls_token)
        if not self.return_attn_wgts:
            return logits
        return logits, attn_wgts, cls_token

In [22]:
model = FGVCModel(encoder, classifier)

In [23]:
attn_wgts[0].shape

torch.Size([2, 12, 577, 577])

In [24]:
model.classifier

Sequential(
  (0): Linear(in_features=768, out_features=768, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=768, out_features=120, bias=True)
)

In [25]:
# def model_splitter(m): return L(m[0], m[1]).map(params)
def model_splitter(m):
    "Split the parameters of `m` into two groups: encoder first, classifier second"
    groups = L(m.encoder, m.classifier)
    return groups.map(params)

In [26]:
cbs = []
# if WANDB: cbs += [WandbCallback(log_preds=False,log_model=False)]
learn = Learner(dls, model, opt_func=ranger, cbs=cbs, metrics=[accuracy], splitter=model_splitter,
                loss_func=LabelSmoothingCrossEntropyFlat(0.1))
learn.to_fp16();

### Train for Warmup

In [27]:
# learn.lr_find()

In [28]:
# epochs = 2

# lr = 3e-3
# learn.freeze()
# learn.fit_one_cycle(epochs, lr_max=(lr), pct_start=0.5)

# lr /= 3 
# learn.unfreeze()
# learn.fit_one_cycle(int(epochs**2), lr_max=slice(lr/100, lr), pct_start=0.5)

# learn.save(f"{arch}_stride_{stride_size}_imsize_{size}")

In [29]:
learn.load(f"{arch}_stride_{stride_size}_imsize_384");

In [30]:
learn.validate()

(#2) [1.0573927164077759,0.9158333539962769]

### Multi Crop Dataset : 1 x (384 px whole image) + 2 x (448 px -> 112 px crops)

In [31]:
from utils.attention import *

In [32]:
# Larger images (448 px) and a smaller batch for the multi-crop experiments
size,bs = 448,16

# Item pipelines: x = open image -> tensor -> random resized crop, y = filename -> label -> category
tfms = [[read_image, ToTensor, RandomResizedCrop(size, min_scale=.75)],
        [read_label, Categorize()]]

# Build the membership lookup once as a set: `o in <numpy array>` rescans the whole
# array for every filename, which makes the split quadratic.
valid_filename_set = set(valid_filenames)
valid_splitter = lambda o: o in valid_filename_set
dsets = Datasets(filenames, tfms=tfms, splits=FuncSplitter(valid_splitter)(filenames))

batch_augs = aug_transforms()
# batch_augs = []

stats = imagenet_stats

# Batch pipeline: int -> float, augmentations, then normalisation with `stats`
batch_tfms = [IntToFloatTensor] + batch_augs + [Normalize.from_stats(*stats)]
dls = dsets.dataloaders(bs=bs, after_batch=batch_tfms)

In [33]:
pretrained_vit_encoder = learn.model.encoder.vit_model

In [34]:
xb_448,yb = dls.one_batch()

In [35]:
xb_384 = F.interpolate(xb_448, size=(384,384))

In [36]:
xb_448.shape, xb_384.shape

(torch.Size([16, 3, 448, 448]), torch.Size([16, 3, 384, 384]))

In [37]:
from torch.utils.checkpoint import checkpoint
class FullImageEncoder(Module):
    "Encoder which takes whole image input then outputs attention weights + layer features"
    def __init__(self, pretrained_vit_encoder, nblocks=11, checkpoint_nchunks=2, return_attn_wgts=True):
        vit = pretrained_vit_encoder

        # embedding layers are taken from the warm-up model (same objects, not copies)
        self.patch_embed = vit.patch_embed
        self.cls_token = vit.cls_token
        self.pos_embed = vit.pos_embed
        self.pos_drop = vit.pos_drop

        # keep only the first `nblocks` transformer blocks (memory trade-off)
        self.blocks = vit.blocks[:nblocks]

        # the ViT's final norm is not used here
#         self.norm = pretrained_vit_encoder.norm

        # blocks with index < checkpoint_nchunks are run under gradient checkpointing
        self.checkpoint_nchunks = checkpoint_nchunks

        # when True, forward returns (features, [attention weights per block]); otherwise features only
        self.return_attn_wgts = return_attn_wgts

    def forward_features(self, x):
        "Embed `x`, prepend the cls token, add position embeddings and run the kept blocks"
        n_images = x.shape[0]
        tokens = self.patch_embed(x)

        # one cls token per image, prepended to the patch tokens
        cls_tokens = self.cls_token.expand(n_images, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        # each block returns (tokens, attention weights); weights are kept only when requested
        attn_wgts = []
        for idx, blk in enumerate(self.blocks):
            if idx < self.checkpoint_nchunks:
                tokens, attn_wgt = checkpoint(blk, tokens)
            else:
                tokens, attn_wgt = blk(tokens)
            if self.return_attn_wgts:
                attn_wgts.append(attn_wgt)

        if self.return_attn_wgts:
            return tokens, attn_wgts
        return tokens

    def forward(self, x):
        return self.forward_features(x)

In [38]:
# take the ViT from the warm-up learner and build the full-image encoder on top of it
pretrained_vit_encoder = learn.model.encoder.vit_model
full_image_encoder = FullImageEncoder(pretrained_vit_encoder, nblocks=11, checkpoint_nchunks=2).cuda() # init full image encoder

In [39]:
x, attn_wgts = full_image_encoder(xb_384)

In [40]:
x.shape, len(attn_wgts)

(torch.Size([16, 577, 768]), 11)

In [41]:
def generate_batch_attention_maps(attn_wgts, targ_sz=None, mode=None):
    """Generate attention flow (rollout) maps from the per-layer attention weights of a transformer.

    attn_wgts: list with one tensor per layer, each BS x heads x T x T where T = 1 (cls) + gx*gy patches.
    targ_sz:   side length of the output maps; required when `mode` is 'bilinear' or 'nearest'.
    mode:      None returns the gx x gy maps as a tensor of shape BS x gx x gy;
               'bilinear' / 'nearest' resize the maps to (targ_sz, targ_sz) and return a
               detached numpy array of shape BS x targ_sz x targ_sz.

    Raises ValueError for an unsupported `mode`, or when a resize is requested without `targ_sz`.
    """
    if mode not in (None, 'bilinear', 'nearest'):
        raise ValueError(f"Unsupported mode {mode!r}; expected None, 'bilinear' or 'nearest'")
    if mode is not None and targ_sz is None:
        raise ValueError("targ_sz is required when mode is 'bilinear' or 'nearest'")

    # Stack all layers - BS x L x heads x T x T
    att_mat = torch.stack(attn_wgts, dim=1)
    # Average the attention weights across all heads - BS x L x T x T
    att_mat = att_mat.mean(dim=2)
    # To account for residual connections, add an identity matrix to the attention
    # matrix of every layer, then re-normalize each row so it sums to 1 again.
    identity = torch.eye(att_mat.size(-1), device=att_mat.device)
    aug_att_mat = att_mat + identity[None, None, ...]
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Recursively multiply the weight matrices, from the first layer to the last
    joint_attentions = aug_att_mat[:, 0]
    for n in range(1, aug_att_mat.size(1)):
        joint_attentions = torch.bmm(aug_att_mat[:, n], joint_attentions)

    # Attention from the cls token (row 0) to every patch: BS x (num_patches) -> BS x gx x gy
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    joint_attentions = joint_attentions[:, 0, 1:].view(joint_attentions.size(0), grid_size, grid_size)
    # Scale every map so its maximum is 1
    joint_attentions = joint_attentions / torch.amax(joint_attentions, dim=(-2, -1), keepdim=True)

    if mode is None:
        return joint_attentions

    # Resize to the target size; the batch is passed as the channel dim of a single 4D input.
    # `align_corners` is only valid for (and was only used with) bilinear interpolation.
    extra = {'align_corners': True} if mode == 'bilinear' else {}
    resized = F.interpolate(joint_attentions[None, ...], (targ_sz, targ_sz), mode=mode, **extra)[0]
    return resized.detach().cpu().numpy()

In [42]:
attention_maps = generate_batch_attention_maps(attn_wgts, None, mode=None)

In [43]:
attention_maps.shape

torch.Size([16, 24, 24])

In [44]:
class SpatialTransfomerBlock(Module):
    "Predict two axis-aligned crop transforms (scale + translation) from attention maps and sample the crops"
    def __init__(self):
        # Spatial transformer localization-network: a single max-pool over the attention maps.
        # With 24 x 24 maps the pooled output is 8 x 8, i.e. the 64 features `fc_loc` expects.
        self.localization = nn.Sequential(
                            nn.MaxPool2d(2, stride=3),
        )

        # Regressor for 2 crops x 4 values (scalex, scaley, translatex, translatey)
        self.fc_loc = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(True),
            nn.Linear(64, 2*4)
        )

        # ranges for each geometric transform
        self.scale_range = SigmoidRange(0,1) # between 0-1 we always crop
        self.translate_range = SigmoidRange(-1,1) # intended to stay inside image

    # Spatial transformer network forward function
    def forward(self, x):
        "Return affine matrices of shape BS x 2 (num_crops) x 2 x 3 for the attention maps `x`"
        feats = self.localization(x)
        feats = feats.view(feats.size(0), -1)
        raw = self.fc_loc(feats).view(x.size(0), 2, 4)

        # squash the raw outputs into their ranges: first two values are scales, last two translations
        scales = self.scale_range(raw[:, :, :2])
        shifts = self.translate_range(raw[:, :, 2:])
        scalex, scaley = scales[:, :, 0], scales[:, :, 1]
        translatex, translatey = shifts[:, :, 0], shifts[:, :, 1]

        # scalex,scaley,translatex,translatey -> affine matrix
        # [scalex,      0,   scalex*translatex]
        # [0     , scaley,   scaley*translatey]
        # zeros are created with the same dtype/device as the predictions so the stack never mixes types
        zeros = torch.zeros(scalex.shape, dtype=scalex.dtype, device=scalex.device)
        theta = torch.stack([scalex, zeros, scalex*translatex,
                             zeros, scaley, scaley*translatey], dim=-1)

        return theta.view(theta.size(0), theta.size(1), 2, 3) # BS x 2 (num_crops) x 2 x 3

    def transform(self, x, theta, targ_sz=112):
        "Sample a `targ_sz` x `targ_sz` crop from every image in `x` using the BS x 2 x 3 matrices `theta`"
        # align_corners=False is PyTorch's default; passing it explicitly keeps the behaviour
        # and silences the default-changed warnings from affine_grid / grid_sample
        grid = F.affine_grid(theta, (x.size(0), x.size(1), targ_sz, targ_sz), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

In [45]:
st_model = SpatialTransfomerBlock().cuda()
theta_crops = st_model(attention_maps)

In [46]:
del attn_wgts
torch.cuda.empty_cache()

In [47]:
theta_crops[0,0], theta_crops[0,1]

(tensor([[ 0.5404,  0.0000, -0.0303],
         [ 0.0000,  0.4677,  0.0425]], device='cuda:0',
        grad_fn=<SelectBackward>),
 tensor([[ 0.4858,  0.0000, -0.0023],
         [ 0.0000,  0.4988,  0.0113]], device='cuda:0',
        grad_fn=<SelectBackward>))

In [48]:
# sample one 112 px crop per predicted affine matrix (theta_crops is BS x 2 crops x 2 x 3);
# the second crop must use the second matrix, not the first one again
xb_112_crop1 = st_model.transform(xb_448, theta_crops[:,0], targ_sz=112)
xb_112_crop2 = st_model.transform(xb_448, theta_crops[:,1], targ_sz=112)

In [49]:
xb_112_crop1.shape, xb_112_crop2.shape

(torch.Size([16, 3, 112, 112]), torch.Size([16, 3, 112, 112]))

In [50]:
crop_image_encoder = deepcopy(full_image_encoder).cuda() # init crop encoder as copy

In [51]:
crop_image_encoder.return_attn_wgts = False

In [52]:
# Shrink the crop encoder's position embeddings from the 384 px token count to the 112 px one.
# 112 px images with 16 px patches (stride 16) give 7 x 7 patches, plus 1 for the cls token.
num_patches = ((112 - 16 + 16) // 16)**2 + 1
pos_embed_data = crop_image_encoder.pos_embed.data
embed_dim = pos_embed_data.size(-1)
# nearest-neighbour resize along the token axis; the embedding width is left unchanged
new_pos_embed_data = F.interpolate(pos_embed_data.unsqueeze(0),
                                   size=[num_patches, embed_dim],
                                   mode='nearest').squeeze(0)
crop_image_encoder.pos_embed.data = new_pos_embed_data

In [53]:
# encode each crop with the crop encoder; the second call must receive the second crop
x_crop1 = crop_image_encoder(xb_112_crop1)
x_crop2 = crop_image_encoder(xb_112_crop2)

In [57]:
x.shape, x_crop1.shape, x_crop2.shape

(torch.Size([16, 577, 768]),
 torch.Size([16, 50, 768]),
 torch.Size([16, 50, 768]))

In [59]:
x = torch.cat([x, x_crop1, x_crop2], dim=1)

In [60]:
x.shape

torch.Size([16, 677, 768])

In [62]:
final_block = Block(dim=768,num_heads=12,mlp_ratio=4.,qkv_bias=True,qk_scale=None).cuda()

In [63]:
x,_ = final_block(x)

In [71]:
norm = partial(nn.LayerNorm, eps=1e-6)(768).cuda()

In [72]:
x = norm(x)

In [76]:
cls_token = x[:,0]

In [79]:
classifier = create_cls_module(nf, dls.c, lin_ftrs=[768], use_bn=False, first_bn=False, ps=0.).cuda()

In [81]:
classifier(cls_token).shape

torch.Size([16, 120])

In [None]:
class STViT(Module):
    def __init__(self, pretrained_vit_encoder):
        "Build the full-image encoder, the spatial transformer and a separate encoder for the crops"
        self.full_image_encoder = FullImageEncoder(pretrained_vit_encoder, nblocks=11, checkpoint_nchunks=6)
        self.st_model = SpatialTransfomerBlock()
        # copy this model's own encoder rather than the notebook-global `full_image_encoder`,
        # so the class does not depend on (or silently pick up) state from earlier cells
        self.crop_image_encoder = deepcopy(self.full_image_encoder)
        
        
        
    def forward(self,x):
        
        
        
        
        
        
        
        
        
        