In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from nosaveddata import *

import numpy as np


#model = ViT(128, 8, 4, first_channel=3).cuda()

#model(torch.randn(16,3,96,72).cuda()).shape

class Transformer_Block_NoLN(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.0, bias=False, ffn_mult=4, stochastic_depth=1):
        super().__init__()
        self.stochastic_depth=stochastic_depth
        self.ln_1 = LayerNormNoBias(d_model, bias=bias)
        self.attn = Attention(d_model, num_heads, bias, dropout)
        self.ln_2 = LayerNormNoBias(d_model, bias=bias)
        self.mlp = FFN(d_model, dropout, bias, ffn_mult)

    def forward(self, x, is_causal=True):
        #x = renormalize(x)
        keep_path = torch.ones(x.shape[0],device='cuda')*(self.stochastic_depth if self.training else 1)
        keep_path = torch.bernoulli(keep_path)[:,None,None]
        
        means, stds = 0, 0
        means += x.mean()
        stds += x.std()
        #x_ln = self.ln_1(x)
        x = x + self.attn(x, x, x, is_causal=is_causal)*keep_path
        means += x.mean()
        stds += x.std()
        x = x + self.mlp(x)*keep_path
        means += x.mean()
        stds += x.std()

        return x, means/3, stds/3

class Transformer_NoDATA(nn.Module):
    def __init__(self, d_model, num_blks, nhead, seq_len,
                 dropout = 0.1, bias=False, report_params_count=True,
                 ffn_mult=4, scale_init=1):
        super().__init__()
        self.num_hiddens = d_model
        self.scale_init=scale_init
        if scale_init==1:
            self.scale_init=num_blks


        self.pos_encoding = nn.Embedding(seq_len, d_model)

        self.final_ln = LayerNormNoBias(d_model)
        self.start_dropout = nn.Dropout(dropout)
        self.seq_len = seq_len
        self.num_blks=num_blks

        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), Transformer_Block_NoLN(
                                d_model, nhead, dropout, bias=False, ffn_mult=ffn_mult,
                                stochastic_depth=1-i/num_blks*(0.9)))



        # https://proceedings.mlr.press/v119/huang20f/huang20f.pdf

        self.apply(init_xavier)
        self.apply(self._init_weights)

        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight') or pn.endswith('W_v.weight') or pn.endswith('fc.weight') or pn.endswith('pos_encoding.weight'):
                torch.nn.init.xavier_uniform_(p, gain=(torch.tensor(4*self.scale_init,dtype=torch.float)).pow(-1/4))
        #for pn, p in self.named_parameters():
        #    if pn.endswith('proj.weight'):
        #        torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_blks))

        if report_params_count:
            params_to_count = [p for p in self.parameters() if p.requires_grad]
            print(f'GPT Transformer Parameters: {sum(p.numel() for p in params_to_count)/1e6:.2f}M')

    def _init_weights(self, module):
        
        if isinstance(module, nn.Embedding):
            #torch.nn.init.normal_(module.weight, mean=0.0, std=1/math.sqrt(self.num_hiddens))
            torch.nn.init.xavier_uniform_(module.weight, gain=(torch.tensor(4*self.scale_init,dtype=torch.float)).pow(-1/4))
        '''
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
                
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)
        '''

    def forward(self, X, is_causal=True):

        pos = torch.arange(0, self.seq_len, dtype=torch.long, device='cuda')
        pos_emb = self.pos_encoding(pos)
        print(X.shape, pos_emb.shape)
        X = self.start_dropout(X+pos_emb)
        print(X.shape)
        X = self.final_ln(X)

        means, stds = 0, 0
        for i, blk in enumerate(self.blks):
            X, mean, std = blk(X, is_causal)
            means += mean
            stds += std

        return X, means/self.num_blks, stds/self.num_blks
    
    def no_pos(self, X, is_causal=True):
        X = self.start_dropout(X)
        X = self.final_ln(X)
        
        for i, blk in enumerate(self.blks):
            X, _, _ = blk(X, is_causal)

        return X
    
    def masked(self, X, mask, is_causal=True):

        pos = torch.arange(0, self.seq_len, dtype=torch.float32, device='cuda')
        pos_emb = self.pos_encoding(pos)
        X = self.start_dropout(X+pos_emb)
        X = X.gather(1, mask)
        
        X = self.final_ln(X)

        
        for i, blk in enumerate(self.blks):
            X, _, _ = blk(X, is_causal)

        return X

model = Transformer_NoDATA(128, 8, 4, seq_len=128, dropout=0).cuda()
model(torch.randn(32,128,128).cuda())[0].shape, model.blks[1].attn.W_v.weight[0,0:3], model.blks[1].mlp.fc.weight[0,0:3], model.blks[1].mlp.proj.weight[0,0:3]

GPT Transformer Parameters: 1.59M
torch.Size([32, 128, 128]) torch.Size([128, 128])
torch.Size([32, 128, 128])


(torch.Size([32, 128, 128]),
 tensor([-0.0277,  0.0546,  0.0071], device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([0.0209, 0.0240, 0.0252], device='cuda:0', grad_fn=<SliceBackward0>),
 tensor([ 0.0340, -0.0399,  0.0073], device='cuda:0', grad_fn=<SliceBackward0>))

In [121]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from nosaveddata import *


class Transformer_Block_NoLN(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.0, bias=False, ffn_mult=4, stochastic_depth=1):
        super().__init__()
        self.stochastic_depth=stochastic_depth
        self.ln_1 = LayerNormNoBias(d_model, bias=bias)
        self.attn = Attention(d_model, num_heads, bias, dropout)
        self.ln_2 = LayerNormNoBias(d_model, bias=bias)
        self.mlp = FFN(d_model, dropout, bias, ffn_mult)

    def forward(self, x, is_causal=True):
        #x = renormalize(x)
        keep_path = torch.ones(x.shape[0],device='cuda')*(self.stochastic_depth if self.training else 1)
        keep_path = torch.bernoulli(keep_path)[:,None,None]
        
        means, stds = 0, 0
        means += x.mean()
        stds += x.std()
        x_ln = self.ln_1(x)
        x = x + self.attn(x_ln, x_ln, x_ln, is_causal=is_causal)*keep_path
        means += x.mean()
        stds += x.std()
        x = x + self.mlp(self.ln_2(x))*keep_path
        means += x.mean()
        stds += x.std()

        return x, means/3, stds/3

class Transformer_NoDATA(nn.Module):
    def __init__(self, d_model, num_blks, nhead, seq_len,
                 dropout = 0.1, bias=False, report_params_count=True,
                 ffn_mult=4, scale_init=1):
        super().__init__()
        self.num_hiddens = d_model
        self.scale_init=scale_init
        if scale_init==1:
            self.scale_init=num_blks


        self.pos_encoding = nn.Embedding(seq_len, d_model)

        self.final_ln = LayerNormNoBias(d_model)
        self.start_dropout = nn.Dropout(dropout)
        self.seq_len = seq_len
        self.num_blks=num_blks

        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), Transformer_Block_NoLN(
                                d_model, nhead, dropout, bias=False, ffn_mult=ffn_mult,
                                stochastic_depth=1-((1-0.9)*i/num_blks) )) 



        # https://proceedings.mlr.press/v119/huang20f/huang20f.pdf

        self.apply(init_gpt)
        self.apply(self._init_weights)

        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight') or pn.endswith('W_v.weight') or pn.endswith('fc.weight') or pn.endswith('pos_encoding.weight'):
                torch.nn.init.xavier_uniform_(p, gain=(torch.tensor(4*self.scale_init,dtype=torch.float)).pow(-1/4))
        #for pn, p in self.named_parameters():
        #    if pn.endswith('proj.weight'):
        #        torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_blks))

        if report_params_count:
            params_to_count = [p for p in self.parameters() if p.requires_grad]
            print(f'GPT Transformer Parameters: {sum(p.numel() for p in params_to_count)/1e6:.2f}M')

    def _init_weights(self, module):
        
        if isinstance(module, nn.Embedding):
            #torch.nn.init.normal_(module.weight, mean=0.0, std=1/math.sqrt(self.num_hiddens))
            torch.nn.init.xavier_uniform_(module.weight, gain=(torch.tensor(4*self.scale_init,dtype=torch.float)).pow(-1/4))
        '''
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
                
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)
        '''

    def forward(self, X, is_causal=True):

        pos = torch.arange(0, self.seq_len, dtype=torch.long, device='cuda')
        pos_emb = self.pos_encoding(pos)[:X.shape[1]]
        X = self.start_dropout(X+pos_emb)

        means, stds = 0, 0
        for i, blk in enumerate(self.blks):
            X, mean, std = blk(X, is_causal)
            means += mean
            stds += std
            
        X = self.final_ln(X)

        return X, means/self.num_blks, stds/self.num_blks
    
    def no_pos(self, X, is_causal=True):
        X = self.start_dropout(X)
        X = self.final_ln(X)
        
        for i, blk in enumerate(self.blks):
            X, _, _ = blk(X, is_causal)

        return X
    
    def masked(self, X, mask, is_causal=True):

        pos = torch.arange(0, self.seq_len, dtype=torch.long, device='cuda')
        pos_emb = self.pos_encoding(pos)[:X.shape[1]]
        X = self.start_dropout(X+pos_emb)
        X = X.gather(1, mask)
        
        X = self.final_ln(X)

        
        for i, blk in enumerate(self.blks):
            X, _, _ = blk(X, is_causal)

        return X
    

class ViT(nsd_Module):
    def __init__(self, d_model, num_blks, nhead, patches=(16,16), img_size=(96,72), first_channel=3,
                 dropout = 0, bias=True, report_params_count=True,
                 ffn_mult=4):
        super().__init__()


        self.patches = np.prod(patches)
        self.N = int(np.prod(img_size)/self.patches)

        self.in_proj = MLP(first_channel*self.patches, out_hiddens=d_model, last_init=init_gpt)

        self.cls = nn.Embedding(1,d_model)
        self.transformer = Transformer_NoDATA(d_model, num_blks, nhead, seq_len=self.N+1,
                 dropout = dropout, bias=bias, report_params_count=False,
                 ffn_mult=ffn_mult)

        self.cls.apply(init_emb)


    def patchify(self, X):
        X = X.view(-1, self.patches*self.first_channel, self.N).transpose(-2,-1)
        return X

    def proj(self, X):
        X = self.patchify(X)
        return self.in_proj(X)

    def transformers(self, X):
        cls = self.cls(torch.zeros(X.shape[0], device='cuda', dtype=torch.long))

        X = torch.cat((X,cls[:,None]), 1)
        X, means, std = self.transformer(X, is_causal=False)

        return X, means, std

    def forward(self, X):
        X = self.patchify(X)
        X = self.in_proj(X)

        cls = self.cls(torch.zeros(X.shape[0], device='cuda', dtype=torch.long))

        X = torch.cat((X,cls[:,None]), 1)
        X, means, std = self.transformer(X, is_causal=False)

        return X, means, std

    def masked(self, X, mask):
        
        X = self.transformer.masked(X, mask, is_causal=False)

        return X


class ViT_IWM(nsd_Module):
    def __init__(self, encoder,
                 d_predictor, num_blks_predictor, nhead_predictor,
                 stacked_frames=4,
                 mask_samples=4,
                 masked_tokens=4,
                 num_augmentations=3,
                 dropout = 0, bias=True, report_params_count=True,
                 ffn_mult=4, lr=1e-3):
        super().__init__()
        d_encoder = encoder.d_model
        self.d_encoder = d_encoder
        
        
        self.first_channel = encoder.first_channel*stacked_frames
        self.img_size = encoder.img_size
        self.patches = encoder.patches
        self.N = encoder.N
        self.masked_tokens=self.N//masked_tokens

        # Mask
        self.mask = MLP(1, out_hiddens=d_predictor, last_init=init_xavier)
        self.mask_pos_encoding = nn.Embedding(self.N, d_predictor)
        self.mask_mlp = MLP(d_predictor+num_augmentations, d_predictor, d_predictor, layers=4, in_act=nn.ReLU(), out_act=nn.ReLU(),
                            init=init_relu, last_init=init_gpt)
        self.mask_pos_encoding.apply(init_gpt)

        # Encoder
        self.encoder = encoder

        # Predictor
        self.predictor_proj = MLP(d_encoder, out_hiddens=d_predictor, last_init=init_gpt) \
                              if d_predictor!=d_encoder else nn.Identity()

        self.predictor = Transformer_NoDATA(d_predictor, num_blks_predictor, nhead_predictor, seq_len=self.N,
                 dropout = dropout, bias=bias, report_params_count=False,
                 ffn_mult=ffn_mult, scale_init=num_blks_predictor)


        self.predictor_out_proj = MLP(d_predictor, out_hiddens=d_encoder, last_init=init_gpt) \
                              if d_predictor!=d_encoder else nn.Identity()

        

        self.head = MLP(d_encoder, out_hiddens=100, layers=1, last_init=init_gpt)

        if report_params_count:
            params_count(self, 'IWM')

    def hard_reset(self, new_network, alpha):
        network_ema(self.encoder, new_network.encoder, alpha)

        network_ema(self.predictor_proj, new_network.predictor_proj, 0.3)
        network_ema(self.predictor, new_network.predictor, 0.3)

        network_ema(self.mask, new_network.mask, 0.3)
        network_ema(self.mask_pos_encoding, new_network.mask_pos_encoding, 0.3)
        network_ema(self.mask_mlp, new_network.mask_mlp, 0.3)

    def get_random_mask(self, X, augmentations):
        B, T, D = X.shape
        B = B//self.stacked_frames
        m_rand = self.mask_samples*random.randint(0,int(self.masked_tokens*2//self.mask_samples)-1)
        
        
        # Get non-overlapping mask
        mask_pos = torch.arange(T, device='cuda')[None,:].repeat_interleave(B,0).float()
        mask_pos = torch.multinomial(mask_pos, num_samples=self.masked_tokens+m_rand, replacement=False)
        
        mask_pos_repeat = mask_pos.repeat_interleave(self.stacked_frames,0)

        # Get the mask complement
        full_range = torch.arange(T,device='cuda')[None,:].repeat_interleave(B,0)

        complement = torch.zeros_like(full_range, dtype=torch.bool)
        complement.scatter_(1, mask_pos, 1)

        complement = full_range[~complement].view(mask_pos.shape[0], -1)
        

        # Mask mlp for geometric + augmentation informations
        mask = self.mask(torch.ones(B*self.stacked_frames,self.masked_tokens+m_rand,1, device='cuda'))

        mask = mask + self.mask_pos_encoding(mask_pos_repeat)

        augmentations = augmentations.repeat_interleave(self.stacked_frames,0)[:,None].expand(-1,mask.shape[1],-1)

        mask = self.mask_mlp(torch.cat((mask,augmentations),-1))

        # Expand to allow gather
        mask_pos = mask_pos[:,:,None].expand(-1,-1,X.shape[-1])
        complement = complement[:,:,None].expand(-1,-1,X.shape[-1])

        return X, mask_pos, complement, mask
    
    def get_block_mask(self, batch_size, M=4):
        
        all_wins = torch.zeros(self.first_channel,*self.img_size).long()
        
        b_mask, b_complement = [], []
        min_c_len = 999 # for trunked collate
        #min_m=999
        
        for b in range(batch_size):
            wins, complements = [], []
            for m in range(self.mask_samples):
                w,h = self.img_size


                min_ar, max_ar = (0.75, 1.5)
                aspect_ratio = min_ar + random.random() * (max_ar - min_ar)

                h_sample_size = int( (h*(torch.tensor(random.random())*0.05+0.15)) * aspect_ratio)

                w_wins, h_wins = torch.randint(0,h-h_sample_size,(2,)).split(1,0)
                win=all_wins.clone()


                for w_win, h_win in zip(w_wins, h_wins):
                    win[:,w_win:w_win+h_sample_size, h_win:h_win+h_sample_size]=1
                    
                win = self.encoder.patchify(win.float()).mean(-1)
                values, idx = win.sort(descending=True)

                idx = idx[:,:4]
                
                #min_m = min(min_m, len(values[0].nonzero()))
                wins.append(idx)


            wins = torch.stack(wins).squeeze()


            full_range = torch.arange(win.shape[1])

            complement = torch.zeros_like(full_range, dtype=torch.bool)
            complement.scatter_(0, wins.view(-1).unique(), 1)

            complement = full_range[~complement]
            min_c_len = min(min_c_len, len(complement))
            
            
            b_mask.append(wins)
            b_complement.append(complement)
            
            
        for i in range(len(b_complement)):
            b_complement[i] = b_complement[i][:min_c_len]
        
        b_mask = torch.stack(b_mask).cuda()
        b_complement = torch.stack(b_complement).cuda()
        #print(min_m)
        
        return b_mask, b_complement
    
    def get_mask(self, X, augmentations):
        B = X.shape[0]
        
        mask_pos, complement = self.get_block_mask(B)
        mask_pos = mask_pos.view(B*self.mask_samples,-1)
    
    
        mask_pos_repeat = mask_pos.repeat_interleave(self.stacked_frames,0)
        
        mask = self.mask(torch.ones(B*self.stacked_frames*self.mask_samples,1,1, device='cuda'))

        mask = mask + self.mask_pos_encoding(mask_pos_repeat)
        
        mask_pos = mask_pos[...,None].expand(-1,-1,self.d_encoder)
        complement = complement[...,None].expand(-1,-1,self.d_encoder)
        
        return mask_pos, mask, complement
    
    def encode(self, X):
        return self.encoder(X)
    def classify(self, X):
        return self.head(X[:,-1])

    def forward(self, X, y, augmentations):
        X = self.encoder.proj(X)
        mask_pos, mask, complement = self.get_mask(X, augmentations)
        
        
        X = self.encoder.masked(X, complement)
        X = self.predictor_proj(X)
        
        
        X = torch.cat((X.repeat_interleave(4,0),mask),1)
        
        X = self.predictor.no_pos(X)[:,-mask.shape[1]:]
        X = self.predictor_out_proj(X)
        #mask_pos = mask_pos.contiguous().view(X.shape[0], -1, X.shape[-1])

        return X, y.repeat_interleave(4,0).gather(1,mask_pos)


encoder = ViT(192, 9, 12, (4,8), (32,32),
                dropout=0, ffn_mult=2).cuda()
model = ViT_IWM(encoder, 96, num_blks_predictor=12, nhead_predictor=12,
                stacked_frames=1, dropout=0, ffn_mult=2).cuda()
y = encoder(torch.randn(32,3,32,32).cuda())[0]
augmentations = F.one_hot(torch.randint(0,1, (32,), device='cuda').long(),3).float()

x1 = model.encode(torch.randn(32,3,32,32).cuda())[0]
x1 = model.classify(x1)

y1 = F.one_hot(torch.randint(0,100,(32,),device='cuda'), 100)

x, y = model(torch.randn(32,3,32,32).cuda(), y, augmentations)


#loss = -(y1*torch.log(x1)).sum(-1).mean()

y = F.layer_norm(y,(y.shape[-1],))

mse = nn.MSELoss(reduction='none')

#x = F.normalize(x,dim=-1)
#y = F.normalize(y,dim=-1)

loss = mse(x,y).sum(-1).mean()*5

loss.backward()


model.encoder.transformer.blks[-1].attn.W_k.weight.grad.mean(), model.encoder.transformer.blks[-1].attn.W_k.weight.grad.max()

IWM Parameters: 3.67M


(tensor(1.6169e-12, device='cuda:0'), tensor(0.0763, device='cuda:0'))

In [4]:
from nosaveddata import *

encoder = ViT(192, 9, 12, (4,8), (32,32),
                dropout=0, ffn_mult=2, stochastic_depth=0.9).cuda()
model = ViT_IWM(encoder, 96, num_blks_predictor=12, nhead_predictor=12,
                stacked_frames=1, dropout=0, ffn_mult=2, stochastic_depth=0.9).cuda()

  from .autonotebook import tqdm as notebook_tqdm


TypeError: ViT.__init__() got an unexpected keyword argument 'stochastic_depth'

In [4]:
import shutil, glob, os

for file in glob.glob('a/*'):
    shutil.copy(file, f'b/{file.split(os.sep)[-1]}')

In [42]:
import math
import torch

p1 = 0.6697
p2 = 0.6649
n = 10000

def statistical_difference(p1, p2, n):
    
    d=torch.tensor(p1-p2).abs()

    std = 1.65 * math.sqrt((p1*(1-p1) + p2*(1-p2))/n)
    
    difference = torch.tensor([d-std, d+std])
    
    return difference.sort()[0]

print(statistical_difference(0.834, 0.831, 100000))

tensor([0.0002, 0.0058])


In [2]:

model = ViT_Temporal(128, 8, temporal_aggr_num_blks=1, nhead=4, first_channel=3).cuda()

model(torch.randn(16,12,96,72).cuda()).shape

GPT Transformer Parameters: 1.58M
GPT Transformer Parameters: 0.21M
ViT Temporal Parameters: 1.89M


torch.Size([16, 27, 128])

In [23]:

class ViT_IWM(nn.Module):
    def __init__(self, encoder, d_encoder,
                 d_predictor, num_blks_predictor, nhead_predictor,
                 out_dim=2048,
                 stacked_frames=4,
                 masked_tokens=4,
                 num_augmentations=3,
                 dropout = 0.1, bias=False, report_params_count=True,
                 ffn_mult=4):
        super().__init__()
        self.d_encoder = d_encoder
        self.stacked_frames=stacked_frames
        
        self.patches = encoder.patches
        self.N = encoder.N
        self.masked_tokens=self.N//masked_tokens
        
        self.encoder = encoder
        
        self.predictor_proj = MLP(d_encoder, out_hiddens=d_predictor, last_init=init_xavier) \
                              if d_predictor!=d_encoder else nn.Identity()
        
        self.predictor = GPT_Transformer(d_predictor, num_blks_predictor, nhead_predictor, seq_len=self.N,
                 dropout = dropout, bias=bias, report_params_count=report_params_count,
                 ffn_mult=ffn_mult)
        
        self.mask = MLP(1, out_hiddens=d_encoder, last_init=init_xavier)
        self.mask_pos_encoding = nn.Embedding(self.N, d_encoder)
        self.mask_mlp = MLP(d_encoder+num_augmentations, d_encoder, d_encoder, layers=4, in_act=nn.ReLU(), out_act=nn.ReLU(),
                            init=init_relu, last_init=init_relu)
        
        
        params_count(self, 'IWM')

    def hard_reset(self, new_network, alpha):
        network_ema(self.encoder, new_network.encoder, alpha)
        
        network_ema(self.predictor_proj, new_network.predictor_proj, 0)
        network_ema(self.predictor, new_network.predictor, 0)

        network_ema(self.mask, new_network.mask, 0)
        network_ema(self.mask_pos_encoding, new_network.mask_pos_encoding, 0)
        network_ema(self.mask_mlp, new_network.mask_mlp, 0)
    
    def get_mask(self, X, augmentations):
        B, T, D = X.shape
        B = B//self.stacked_frames
        m_rand = random.randint(0,self.masked_tokens*2)
        
        mask_pos = torch.randint(0, T, (B,self.masked_tokens+m_rand), device='cuda')
        mask_pos_repeat = mask_pos.repeat_interleave(self.stacked_frames,0)
        
        X_mask_pos = (mask_pos_repeat + torch.arange(B, device='cuda').repeat_interleave(self.stacked_frames,0)[:,None]*B).view(-1)
        
        
        mask = self.mask(torch.ones(B*self.stacked_frames,self.masked_tokens+m_rand,1, device='cuda'))
        
        mask = mask + self.mask_pos_encoding(mask_pos_repeat)
        augmentations = augmentations.repeat_interleave(self.stacked_frames,0)[:,None].expand(-1,mask.shape[1],-1)
        
        mask = self.mask_mlp(torch.cat((mask,augmentations),-1))
        
        
        X.view(-1,D)[X_mask_pos]=X.view(-1,D)[X_mask_pos]*0+mask.view(-1,D)
        
        mask_pos = mask_pos[:,:self.masked_tokens,None].expand(-1,-1,X.shape[-1])
        
        
        return X, mask_pos
    
    def encode(self, X):
        return self.encoder(X)

    
    def forward(self, X, y, augmentations):
        X = self.encoder.proj(X)
        X_masked, mask_pos = self.get_mask(X, augmentations)
        X = self.encoder.transformers(X_masked)
        
        X = self.predictor_proj(X)
        
        X = self.predictor(X)
        mask_pos = mask_pos.contiguous().view(X.shape[0], -1, X.shape[-1])
        
        return X.gather(1,mask_pos), y.gather(1,mask_pos)

encoder = ViT_Temporal(128, 8, patches=(8,8), temporal_aggr_num_blks=1, nhead=4, first_channel=3).cuda()

model = ViT_IWM(encoder, 128,
                128, 8, 4).cuda()

x = torch.randn(16,12,96,72).cuda()

with torch.no_grad():
    y = model.encode(torch.randn(16,12,96,72).cuda())

print(f"\npost temporal {y.shape}\n")
augmentations = torch.bernoulli(torch.ones(x.shape[0], 3)*0.2).cuda()

x, y_tgt = model(x, model.predictor_proj(y), augmentations)

loss = nn.MSELoss(reduction='none')

x=F.normalize(x)
y_tgt=F.normalize(y_tgt)

print(y.shape)


x.shape, y_tgt.shape, loss(x,y_tgt).sum(-1).mean()



#model.encode(torch.randn(32,12,96,72).cuda()).shape

GPT Transformer Parameters: 1.59M
GPT Transformer Parameters: 0.25M
ViT Temporal Parameters: 1.87M
GPT Transformer Parameters: 1.59M
IWM Parameters: 3.54M
post temporal torch.Size([16, 108, 128])
torch.Size([16, 108, 128])


(torch.Size([16, 27, 128]),
 torch.Size([16, 27, 128]),
 tensor(6.9166, device='cuda:0', grad_fn=<MeanBackward0>))

In [95]:
import torch
from torch import nn

def init_saving_variance(module, num_blks):
    
    torch.nn.init.xavier_uniform_(module.weight, gain=torch.tensor(9*num_blks).pow(-1/4))
    if hasattr(module, 'bias'):
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
            
model = nn.Embedding(32,512)
print(f"{model.weight[0,0]}")
init_saving_variance(model, 3)
print(f"{model.weight[0,0]}")

0.08193082362413406
0.025160208344459534


In [4]:
import math
import torch

a=torch.tensor(4*12)

a.pow(-1/4)

tensor(0.3799)

In [26]:
2%3, 192//12, 128//16

(2, 16, 8)

In [19]:
import math
import torch

a=torch.tensor(9*12)
b=torch.tensor(0.67*12)


math.sqrt(2/(512*2)), (a).pow(-1/4), (b).pow(-1/4)

(0.04419417382415922, tensor(0.3102), tensor(0.5939))

In [8]:
96/12, 72/12, 48*128

(8.0, 6.0, 6144)

In [5]:
0.15*196, 0.2*196

(29.4, 39.2)

In [6]:
model = IMPALA_Resnet(4,4)

IMPALA ResNet Parameters: 1.56M


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from nosaveddata import *

import numpy as np

In [8]:
class UNet_DiT(nn.Module):
    def __init__(self, in_channels, d_model, num_blks, nhead, patch=(2,2), img_size=(32,32),
                             dropout = 0.1, bias=False, report_params_count=True,
                             ffn_mult=4):
        super().__init__()
        self.first_channel=in_channels
        self.patches = np.prod(patch)
        self.img_size=img_size
        self.N = int(np.prod(img_size)/self.patches)
        
        self.ts = TimestepEmbedder(d_model)
        
        self.in_proj = MLP(in_channels*self.patches, out_hiddens=d_model, last_init=init_xavier)
        
        self.dit =  DiT_Transformer(d_model, num_blks, nhead, self.patches,
                             dropout = 0.1, bias=False, report_params_count=True,
                             ffn_mult=4)
        self.final_layer = DiT_FinalLayer(d_model, patch, in_channels)
        
        self.init_weights()
    
    def init_weights(self):
        # Zero-out output layers:
        self.final_layer.adaLN_modulation[-1].apply(init_zeros)
        self.final_layer.linear.apply(init_zeros)
    
    def patchify(self, X):
        X = X.view(-1, self.patches*self.first_channel, self.N).transpose(-2,-1)
        return X
    def depatchify(self, X):
        X = X.transpose(-2,-1).contiguous().view(-1, self.first_channel,*self.img_size)
        return X
    
    def forward(self, x, t):
        c = self.ts(t)
        
        x = self.patchify(x)
        x = self.in_proj(x)
        
        print(x.shape)
        x = self.dit(x, c)
        
        x = self.final_layer(x, c)
        x = self.depatchify(x)
        
        return x


class UNet_DiT_1D(nn.Module):
    def __init__(self, in_channels, d_model, num_blks, nhead, seq_len,
                             dropout = 0.1, bias=False, report_params_count=True,
                             ffn_mult=4):
        super().__init__()
        self.first_channel=in_channels
        
        self.ts = TimestepEmbedder(d_model)
        
        self.in_proj = MLP(in_channels, out_hiddens=d_model, last_init=init_xavier) if in_channels!=d_model else nn.Identity()
        
        self.dit =  DiT_Transformer(d_model, num_blks, nhead, seq_len,
                             dropout = 0.1, bias=False, report_params_count=True,
                             ffn_mult=4)
        
        self.out_proj = MLP(d_model, out_hiddens=in_channels, last_init=init_xavier) if in_channels!=d_model else nn.Identity()
        
    
    def forward(self, x, t):
        c = self.ts(t)
        x = self.in_proj(x)
        x = self.dit(x, c)
        x = self.out_proj(x)
        return x

In [9]:
#model = UNet_DiT(4, 512, 8, 8, patch=(4,4)).cuda()

#x=torch.randn(16,4,32,32).cuda()
x=torch.randn(16,33,512).cuda()
c=torch.randint(0,1000,(x.shape[0],)).cuda()


model = UNet_DiT_1D(512, 512, 8, 8, seq_len=33).cuda()
#model = UNet_DiT_XL_2(in_channels=4, img_size=(32,32)).cuda()
#model = UNet_DiT_XL_2(in_channels=4, patch=(2,2), img_size=(32,32)).cuda()

with torch.no_grad():
    print(model(x,c).shape)

DiT Transformer Parameters: 37.80M
torch.Size([16, 33, 512])


In [10]:
model = nn.Linear(10,2).cuda()
model.apply(init_xavier)
model2 = nn.Linear(10,2).cuda()
network_ema(model, model2, 0)
#model.apply(init_xavier)

model.weight.data==model2.weight.data

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]],
       device='cuda:0')

<h1>Preprocessing</h1>

In [11]:
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import os, glob
from nosaveddata import *


import torchvision
from torchvision import transforms

paths = glob.glob('C:/Users/Augusto/Python/PyTorch/RL/mc_data/4/2023_01_09_14_48_09_100636/*.jpg')
path = 'C:/Users/Augusto/Python/PyTorch/RL/mc_data/4/2023_01_09_14_48_09_100636/7,0,0,0,0,0,0,0,0,0,0,0,0,3,0,.jpg'



tfms = transforms.Compose([
                           transforms.Resize((96, 72)),
                           transforms.ToTensor()
                        ])

img = Image.open(path)
imgs=[]
for p in paths:
    imgs.append(tfms(Image.open(p)))
imgs=torch.stack(imgs)

print(imgs.shape)



imgs, augments_applied = preprocess_iwm_no_solarize(imgs)
    


#plt.imshow(img_tfms)
plot_imgs(imgs.permute(0,2,3,1))
augments_applied

FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/Augusto/Python/PyTorch/RL/mc_data/4/2023_01_09_14_48_09_100636/7,0,0,0,0,0,0,0,0,0,0,0,0,3,0,.jpg'

In [8]:
import torch
from torch import nn
import torch.nn.functional as F

from nosaveddata import *



def gray_scale_stacked(X, p=0.2, stacks=4):
    # Input: Tensor T e (B,C,T,D)
    
    probs = get_img_preprocessing_prob(X.shape[0], p, X.device)
    stacked_probs = probs.repeat_interleave(stacks,0)
    X = X.view(-1,X.shape[1]//stacks,*X.shape[-2:])
    
    gray_img = X.mean(1,keepdim=True).expand(-1,3,-1,-1)
    
    X = (1-stacked_probs)*X + stacked_probs*gray_img
    
    return X.view(X.shape[0]//stacks, -1, *X.shape[-2:]), probs.squeeze()

def gaussian_blur(X, p=0.2, stacks=4, sigma_min=0.1, sigma_max=2):
    # Input: Tensor T e (B,C,T,D)
    
    probs = get_img_preprocessing_prob(X.shape[0], p, X.device)
    tfms = transforms.GaussianBlur(3, (sigma_min, sigma_max))
    
    blurred = tfms(X)
    X = (1-probs)*X + probs*blurred
    
    return X, probs.squeeze()

def solarization_stacked(X, p=0.2, stacks=4):
    # Input: Tensor T e (B,C,T,D)

    probs = get_img_preprocessing_prob(X.shape[0], p, X.device)
    stacked_probs = probs.repeat_interleave(stacks,0)
    
    X = X.view(-1,X.shape[1]//stacks,*X.shape[-2:])
    
    tfms = transforms.RandomSolarize(0,p=1) # This prob is applied over all the batch or no image at all
    
    solarized = tfms(X)
    X = (1-stacked_probs)*X + stacked_probs*solarized
    
    return X.view(X.shape[0]//stacks, -1, *X.shape[-2:]), probs.squeeze()


def preprocess_iwm_stacked(imgs, p=0.2, stacks=4):
    # Applies the same preprocessing for all images in the sequence, but separated by each beach
    augments_applied=[]
    
    imgs, augmented = gray_scale_stacked(imgs, p, stacks)
    augments_applied.append(augmented)
    
    imgs, augmented = gaussian_blur_stacked(imgs, p, stacks)
    augments_applied.append(augmented)
    
    imgs, augmented = solarization_stacked(imgs, p, stacks)
    augments_applied.append(augmented)
    
    augments_applied = torch.stack(augments_applied,1)
    return imgs, augments_applied

preprocess_iwm_stacked(torch.randn(32,12,96,72, device='cuda'))[0].shape

torch.Size([32, 12, 96, 72])

In [None]:
plot_img(imgs[-1].permute(1,2,0))

<h1>DiT</h1>

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from nosaveddata import *


model = DiT_Transformer(128, 8, 8, 108).cuda()

X = torch.randn(16,108,128).cuda()
c = torch.randn(16,128).cuda()

model(X,c).shape

DiT Transformer Parameters: 2.38M


torch.Size([16, 108, 128])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from nosaveddata import *

import numpy as np


def modulate(x, shift, scale):
    return x * (1 + scale[:,None]) + shift[:,None]
    
class DiT_Block(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.0, bias=False, ffn_mult=4):
        super().__init__()
        self.ln_1 = LayerNormNoBias(d_model, bias=bias)
        self.attn = Attention(d_model, num_heads, bias, dropout)
        self.ln_2 = LayerNormNoBias(d_model, bias=bias)
        self.mlp = FFN(d_model, dropout, bias, ffn_mult)
        
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 6 * d_model, bias=True)
        )
        self.adaLN_modulation.apply(init_zeros)
        
    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x_ln = modulate(self.ln_1(x), shift_msa, scale_msa)
        x = x + gate_msa[:,None] * self.attn(x_ln, x_ln, x_ln, is_causal=False)
        x = x + gate_mlp[:,None] * self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp))
        return x
    
    
class DiT_Transformer(nn.Module):
    def __init__(self, d_model, num_blks, nhead, seq_len,
                 dropout = 0.1, bias=False, report_params_count=True,
                 ffn_mult=4):
        super().__init__()
        self.num_hiddens = d_model

        self.pos_encoding = nn.Sequential(nn.Linear(seq_len, d_model, bias=False),
                                          LayerNormNoBias(d_model)) #Stable Embedding Layer
        
        self.final_ln = LayerNormNoBias(d_model)
        self.start_dropout = nn.Dropout(dropout)
        self.seq_len = seq_len

        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), DiT_Block(
                                d_model, nhead, dropout, bias=False, ffn_mult=ffn_mult))
            
        
        #nn.init.xavier_uniform_(self.pos_encoding[0].weight)
        
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * num_blks))
        
        if report_params_count:
            params_to_count = [p for p in self.parameters() if p.requires_grad]
            print(f'DiT Transformer Parameters: {sum(p.numel() for p in params_to_count)/1e6:.2f}M')
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            #torch.nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            #torch.nn.init.xavier_normal_(module.weight)

        
    def forward(self, X, c):
        # Input:
        # X e (B, T, D)
        # c e (B, D)
        
        pos = torch.arange(0, self.seq_len, dtype=torch.float32, device='cuda')
        pos_emb = self.pos_encoding(pos)
        X = self.start_dropout(X+pos_emb)

        for i, blk in enumerate(self.blks):
            X = blk(X, c)
            
        return self.final_ln(X)

In [None]:
model = DiT_Transformer(512, 8, 8, 128).cuda()

X = torch.randn(16,128,512).cuda()
c = torch.randn(16,512).cuda()

model(X,c).shape