In [4]:
import shutil, glob, os

for file in glob.glob('a/*'):
    shutil.copy(file, f'b/{file.split(os.sep)[-1]}')

In [14]:
import torch
from torch import nn
from torch.nn import functional as F


support = torch.linspace(-10, 10, 10)

q = torch.zeros(10)
q[7]=2

q = F.softmax(q,-1)

q, (q*support).sum()

(tensor([0.0610, 0.0610, 0.0610, 0.0610, 0.0610, 0.0610, 0.0610, 0.4509, 0.0610,
         0.0610]),
 tensor(2.1658))

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

from nosaveddata import *


def get_patches(x, patch_shape):
    c, (h, w) = x.shape[1], patch_shape
    
    return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).float()

def get_whitening_parameters(patches):
    n,c,h,w = patches.shape
    patches_flat = patches.view(n, -1)
    est_patch_covariance = (patches_flat.T @ patches_flat) / n
    
    eigenvalues, eigenvectors = torch.linalg.eigh(est_patch_covariance, UPLO='U')
    
    return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.T.reshape(c*h*w,c,h,w).flip(0)

def init_whitening_conv(layer, train_set, eps=5e-4):
    patches = get_patches(train_set, patch_shape=layer.weight.data.shape[2:])
    
    eigenvalues, eigenvectors = get_whitening_parameters(patches)
    
    eigenvectors_scaled = eigenvectors / torch.sqrt(eigenvalues + eps)
    layer.weight.data[:] = torch.cat((eigenvectors_scaled, -eigenvectors_scaled))
    layer.weight.requires_grad=False




class IMPALA_Resnet_Whitened(nsd_Module):
    def __init__(self, first_channels=12, scale_width=1, norm=True, init=init_partial_dirac, act=nn.SiLU()):
        super().__init__()
        # lhs 2 is because we use concatenate positive and negative eigenvectors, 3 is the kernel size
        self.whitened_channels = 2 * first_channels * 3**2
        
        self.cnn = nn.Sequential(self.whitened_block(first_channels, 16*scale_width),
                                 self.get_block(16*scale_width, 32*scale_width),
                                 self.get_block(32*scale_width, 32*scale_width, last_relu=True))
        
        self.cnn[0][1].apply(init)
        params_count(self, 'IMPALA ResNet')

    def whitened_block(self, in_hiddens, out_hiddens, last_relu=False):
        
        blocks = nn.Sequential(DQN_Conv(in_hiddens, self.whitened_channels, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
                               nn.Conv2d(self.whitened_channels,out_hiddens, 1, padding=0, stride=1),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init, out_act=self.act if last_relu else nn.Identity())
                              )
        
        return blocks
    
    def get_block(self, in_hiddens, out_hiddens, last_relu=False):
        
        blocks = nn.Sequential(DQN_Conv(in_hiddens, out_hiddens, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init, out_act=self.act if last_relu else nn.Identity())
                              )
        
        return blocks
        
    def forward(self, X):
        return self.cnn(X)


x=torch.randn(500,12,96,72)
IMPALA_Resnet(12,scale_width=4)
network = IMPALA_Resnet_Whitened(12,scale_width=4)

init_whitening_conv(network.cnn[0][0].conv[0], x)





  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


IMPALA ResNet Parameters: 1.56M
IMPALA ResNet Parameters: 1.59M


In [16]:
steps=40000
k=5
sched = 0.95**k * (torch.arange(steps+1) / steps)**3

sched[::500]

tensor([0.0000e+00, 1.5113e-06, 1.2090e-05, 4.0805e-05, 9.6723e-05, 1.8891e-04,
        3.2644e-04, 5.1837e-04, 7.7378e-04, 1.1017e-03, 1.5113e-03, 2.0115e-03,
        2.6115e-03, 3.3203e-03, 4.1470e-03, 5.1006e-03, 6.1902e-03, 7.4250e-03,
        8.8138e-03, 1.0366e-02, 1.2090e-02, 1.3996e-02, 1.6092e-02, 1.8388e-02,
        2.0892e-02, 2.3614e-02, 2.6562e-02, 2.9747e-02, 3.3176e-02, 3.6859e-02,
        4.0805e-02, 4.5023e-02, 4.9522e-02, 5.4311e-02, 5.9400e-02, 6.4797e-02,
        7.0511e-02, 7.6551e-02, 8.2928e-02, 8.9648e-02, 9.6723e-02, 1.0416e-01,
        1.1197e-01, 1.2016e-01, 1.2874e-01, 1.3772e-01, 1.4710e-01, 1.5691e-01,
        1.6714e-01, 1.7780e-01, 1.8891e-01, 2.0047e-01, 2.1250e-01, 2.2500e-01,
        2.3797e-01, 2.5144e-01, 2.6541e-01, 2.7988e-01, 2.9487e-01, 3.1039e-01,
        3.2644e-01, 3.4303e-01, 3.6018e-01, 3.7789e-01, 3.9618e-01, 4.1504e-01,
        4.3449e-01, 4.5454e-01, 4.7520e-01, 4.9647e-01, 5.1837e-01, 5.4091e-01,
        5.6409e-01, 5.8792e-01, 6.1241e-

In [16]:
class LookaheadState:
    def __init__(self, net, steps, k=5):
        self.k=k
        self.net_ema = {k: v.clone() for k, v in net.state_dict().items()}
        self.sched = 0.95**k * (torch.arange(steps+1) / steps)**3

    def update(self, net, step):
        decay = self.sched[step].item()
        if step%self.k==0:
            for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()):
                ema_param.lerp_(net_param, 1-decay)
                net_param.copy_(ema_param)
                
    def update_fixed_decay(self, net, decay, step):
        if step%self.k==0:
            for ema_param, net_param in zip(self.net_ema.values(), net.state_dict().values()):
                ema_param.lerp_(net_param, 1-decay)
                net_param.copy_(ema_param)

lookahead_state = LookaheadState(network, 40000)

lookahead_state.update(network, 8)

def Triangle_Scheduler(optimizer, steps, start=0.2, end=0.07, peak=0.23):
    def triangle(steps, start, end, peak):
        xp = torch.tensor([0, int(peak * steps), steps])
        fp = torch.tensor([start, 1, end])
        x = torch.arange(1+steps)
        m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
        b = fp[:-1] - (m * xp[:-1])
        indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1
        indices = torch.clamp(indices, 0, len(m) - 1)
        return m[indices] * x + b[indices]
    lr_schedule = triangle(steps, start, end, peak)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: lr_schedule[i])

momentum=0.9
lr = 0.1 / (1+1/(1-momentum))
print(f"{lr}")

optim = torch.optim.SGD(network.parameters(), lr=lr, weight_decay=0.1, momentum=momentum, nesterov=True)

sched = Triangle_Scheduler(optim, 40000)

print(f"{optim.param_groups[0]['lr']}")
for i in range(int(40000*0.23)):
    sched.step()
print(f"{optim.param_groups[0]['lr']}")

total_train_steps=40000


lookahead_state.update(network, 8)

0.00909090909090909
0.001818181830458343
0.00909090880304575


In [35]:
lr = 11.5/1024# * (1 + 1 / (1 - 0.85))

lr, 0.0153/(1 + 1 / (1 - 0.85))*lr

(0.01123046875, 2.2412109374999998e-05)

In [48]:
class Conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding='same', bias=False):
        super().__init__(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

    def reset_parameters(self):
        super().reset_parameters()
        if self.bias is not None:
            self.bias.data.zero_()
        w = self.weight.data
        torch.nn.init.dirac_(w[:w.size(1)])

def init_partial_dirac(module):
    if type(module) in (nn.Linear, nn.Conv2d, nn.Conv1d, nn.Conv3d):
        w = module.weight.data
        
        nn.init.dirac_(module.weight[:w.shape[1]])
        nn.init.xavier_uniform_(module.weight[w.shape[1]:], gain=1)

        if module.bias is not None:
            nn.init.zeros_(module.bias)
            

model = Conv(16,32)
model.apply(init_partial_dirac)
last_relu=False
act=nn.ReLU()

in_hiddens=16
out_hiddens=32

norm = False
model = nn.Sequential(DQN_Conv(in_hiddens, out_hiddens, 3, 1, 1, max_pool=True, act=act, norm=norm, init=init_partial_dirac),
                               Residual_Block(out_hiddens, out_hiddens, norm=norm, act=act, init=init_partial_dirac),
                               Residual_Block(out_hiddens, out_hiddens, norm=norm, act=act, init=init_partial_dirac, out_act=act if last_relu else nn.Identity()),
                               MLP(512,512)
                              )

In [24]:
import re
from unidecode import unidecode
from phonemizer import phonemize
from phonemizer.backend import EspeakBackend
backend = EspeakBackend('en-us', preserve_punctuation=True, with_stress=True)

def english_cleaners2(text):
  '''Pipeline for English text, including abbreviation expansion. + punctuation + stress'''
  text = unidecode(text)
  print(f"{text}")
  text = text.lower()
    
  print(f"{text}")
  phonemes = backend.phonemize([text], strip=True)[0]
  phonemes = collapse_whitespace(phonemes)
  return phonemes

english_cleaners2("Olá doutor")


Ola doutor
ola doutor


NameError: name 'backend' is not defined

In [20]:
import torch
import torch.nn.functional as F
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10,30)
    
    def forward(self, x):
        return self.linear(x)

model = Model()
model = torch.compile(model)

x=torch.randn(1,10)
model(x)


  from .autonotebook import tqdm as notebook_tqdm


RuntimeError: Windows not yet supported for torch.compile

In [9]:
import spacy
import epitran

# Carregar o modelo em português
nlp = spacy.load('pt_core_news_sm')

# Texto de exemplo
texto = "O cachorro correu pelo parque."

#texto = "Exemplo de texto para fonemização."

doc = nlp(texto)
for token in doc:
    try:
        fonemas = fonemizar_texto(token.text)
        print(f'{token.text} -> {fonemas}')
    except Exception as e:
        pass


fonemizar_texto(tokens)



O -> o
cachorro -> kɐkoʁo
correu -> koʁew
pelo -> pɛlo
parque -> pɐɾkʷɛ
. -> .


AttributeError: 'list' object has no attribute 'lower'

In [54]:
from nosaveddata import *
import torch
from torch import nn

class Network(nn.Module):
    def __init__(self, len_state, num_quant, num_actions):
        nn.Module.__init__(self)
       
        self.num_quant = num_quant
        self.num_actions = num_actions
       
        self.layer1 = nn.Linear(len_state, 256)
        self.layer2 = nn.Linear(256, num_actions*num_quant)  

    def forward(self, x):
        x = self.layer1(x)
        x = torch.tanh(x)
        x = self.layer2(x)
        return x.view(-1, self.num_actions, self.num_quant)
   
    def select_action(self, state, eps):
        if not isinstance(state, torch.Tensor):
            state = torch.Tensor([state])    
        action = torch.randint(0, 2, (1,))
        if random.random() > eps:
            action = self.forward(state).mean(2).max(1)[1]
        return int(action)
   

eps_start, eps_end, eps_dec = 0.9, 0.1, 500
eps = lambda steps: eps_end + (eps_start - eps_end) * np.exp(-1. * steps / eps_dec)

Z = Network(len_state=8, num_quant=2, num_actions=7)
Ztgt = Network(len_state=8, num_quant=2, num_actions=7)
Ztgt.load_state_dict(Z.state_dict())
tau = torch.Tensor((2 * np.arange(Z.num_quant) + 1) / (2.0 * Z.num_quant)).view(1, -1)

batch_size=3

def huber(x, k=1.0):
    return torch.where(x.abs() < k, 0.5 * x.pow(2), k * (x.abs() - 0.5 * k))

next_states = torch.randn(batch_size,8)
states = next_states + torch.randn(batch_size,8)*0.01

gamma=0.997
rewards=torch.ones(batch_size,1)


theta = Z(states)
print(f"{theta.shape, theta.mean(-1).argmax(-1)}")
theta = theta[np.arange(batch_size), theta.mean(2).max(1)[1]]


Znext = Ztgt(next_states).detach()
Znext_max = Znext[np.arange(batch_size), Znext.mean(2).max(1)[1]]

print(f"{Znext.mean(2).max(1)[1]}")

Ttheta = rewards + gamma  * Znext_max

print(f"{Ttheta.t()[..., None].shape, theta.shape}")

diff = Ttheta.t()[..., None] - theta

loss = huber(diff) * (tau - (diff.detach() < 0).float()).abs()


loss, diff, Ttheta, theta

(torch.Size([3, 7, 2]), tensor([1, 1, 0]))
tensor([1, 1, 0])
(torch.Size([2, 3, 1]), torch.Size([3, 2]))


(tensor([[[0.1260, 0.1263],
          [0.1239, 0.5740],
          [0.1255, 0.0062]],
 
         [[0.2315, 0.3767],
          [0.0653, 0.3695],
          [0.4300, 0.3719]]], grad_fn=<MulBackward0>),
 tensor([[[ 1.0039,  0.5802],
          [ 0.9957,  1.2654],
          [ 1.0019, -0.2223]],
 
         [[ 1.4260,  1.0023],
          [ 0.7229,  0.9927],
          [ 2.2201,  0.9959]]], grad_fn=<SubBackward0>),
 tensor([[0.8626, 1.2847],
         [1.2662, 0.9935],
         [0.5365, 1.7547]]),
 tensor([[-0.1413,  0.2824],
         [ 0.2705,  0.0008],
         [-0.4653,  0.7588]], grad_fn=<IndexBackward0>))

In [7]:
from nosaveddata import *
import torch
from torch import nn

a=torch.arange(2,device='cuda').long()[:,None].repeat_interleave(15,0)

a,torch.zeros(6,1,device='cuda').long()

(tensor([[0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]], device='cuda:0'),
 tensor([[0],
         [0],
         [0],
         [0],
         [0],
         [0]], device='cuda:0'))

In [42]:
import math
import torch

p1 = 0.6697
p2 = 0.6649
n = 10000

def statistical_difference(p1, p2, n):
    
    d=torch.tensor(p1-p2).abs()

    std = 1.65 * math.sqrt((p1*(1-p1) + p2*(1-p2))/n)
    
    difference = torch.tensor([d-std, d+std])
    
    return difference.sort()[0]

print(statistical_difference(0.834, 0.831, 100000))

tensor([0.0002, 0.0058])


In [13]:
from nosaveddata import *
import torch
from torch import nn

class IMPALA_Resnet(nn.Module):
    def __init__(self, first_channels=12, scale_width=1, norm=True, init=init_relu, act=nn.SiLU()):
        super().__init__()
        self.norm=norm
        self.init=init
        self.act =act
        
        self.cnn = nn.Sequential(self.get_block(first_channels, 16*scale_width),
                                 self.get_block(16*scale_width, 32*scale_width),
                                 self.get_block(32*scale_width, 32*scale_width, last_relu=True))
        params_count(self, 'IMPALA ResNet')
    def get_block(self, in_hiddens, out_hiddens, last_relu=False):
        
        blocks = nn.Sequential(DQN_Conv(in_hiddens, out_hiddens, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init, out_act=self.act if last_relu else nn.Identity())
                              )
        
        return blocks
        
    def forward(self, X):
        return self.cnn(X)


class IMPALA_YY(nn.Module):
    def __init__(self, first_channels=12, scale_width=1, norm=True, init=init_relu, act=nn.SiLU()):
        super().__init__()
        self.norm=norm
        self.init=init
        self.act =act

        self.yin = self.get_yin(first_channels, 16*scale_width, 32*scale_width)
        
        self.yang = self.get_yang(first_channels, 16*scale_width)
                                 
        self.head = nn.Sequential(self.get_yang(16*scale_width, 32*scale_width),
                                  self.get_yang(32*scale_width, 32*scale_width, last_relu=True))
        
        params_count(self, 'IMPALA ResNet')

    def get_yin(self, in_hiddens, hiddens, out_hiddens):
        blocks = nn.Sequential(DQN_Conv(1, hiddens, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
                               Residual_Block(hiddens, hiddens, norm=self.norm, act=self.act, init=self.init),
                               #DQN_Conv(hiddens, out_hiddens, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
                               #Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init),
                               #Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init)
                              )
        return blocks          
        
    def get_yang(self, in_hiddens, out_hiddens, last_relu=False):
        
        blocks = nn.Sequential(DQN_Conv(in_hiddens, out_hiddens, 3, 1, 1, max_pool=True, act=self.act, norm=self.norm, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init),
                               Residual_Block(out_hiddens, out_hiddens, norm=self.norm, act=self.act, init=self.init, out_act=self.act if last_relu else nn.Identity())
                              )
        
        return blocks
    
    def forward(self, X):

        y = self.yin(X[:,-3:].mean(-3)[:,None])
        x = self.yang(X)
        
        X = x*(1-y) + x + y
        
        return self.head(X)

model = IMPALA_Resnet(scale_width=4)
x=torch.randn(32,12,96,72)
model2 = IMPALA_YY(scale_width=4)

model(x).shape, model2(x).shape

IMPALA ResNet Parameters: 1.56M
IMPALA ResNet Parameters: 1.63M


(torch.Size([32, 128, 12, 9]), torch.Size([32, 128, 12, 9]))

In [28]:
import torch
from torch import nn
import torch.nn.functional as F
from nosaveddata import *

seed_np_torch(42)

def network_ema(target_network, new_network, alpha=0.5):
    for (param_name, param_target), param_new  in zip(target_network.cuda().named_parameters(), new_network.parameters()):
        if 'ln' in param_name: #layer norm
            param_target.data = param_new.data.clone()
        else:
            param_target.data = alpha * param_target.data + (1 - alpha) * param_new.data.clone()


class Modeld(nsd_Module):
    def __init__(self):
        super().__init__()

        self.linear = nn.Linear(10,32)
        self.ln = nn.LayerNorm(32)

    def forward(self,X):
        return self.ln(self.linear(X))

m = Modeld().cuda()
m_rand= Modeld().cuda()


optim=torch.optim.AdamW(m.parameters(), lr=1e-4)

for i in range(4000):
    x=torch.randn(1,10).cuda()
    
    loss = m(x).sum()
    loss.backward()
    
    optim.step()
    optim.zero_grad()

network_ema(m,m_rand)

m.ln.weight, m.linear.weight

(Parameter containing:
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([[ 0.1788,  0.1215, -0.1596,  0.0556,  0.0823, -0.0129, -0.1189,  0.1854,
          -0.0022, -0.1650],
         [-0.1079, -0.0155,  0.0935,  0.0209,  0.0326, -0.1374, -0.1405,  0.0014,
           0.1486,  0.0473],
         [-0.1218, -0.0415, -0.1404, -0.0332, -0.0325,  0.0417,  0.1003, -0.1978,
           0.1183, -0.2110],
         [ 0.1376,  0.0622,  0.0658,  0.1490, -0.1540, -0.0291,  0.1021,  0.0194,
          -0.0155, -0.1166],
         [ 0.1413,  0.0467,  0.0852, -0.0416, -0.0986, -0.0094,  0.0798, -0.0597,
          -0.0080,  0.0361],
         [-0.0403, -0.0299, -0.0763, -0.1011, -0.1358, -0.0595, -0.0660,  0.0495,
           0.0058, -0.1400],
         [ 0.1676, -0.0036,  0.1435, -0.1102, -0.0544,  0.0415, -0.0507, -0.1388,
          -0.

In [2]:
from nosaveddata import *
import torch
from torch import nn

model = nn.Linear(10,2).cuda()
model.apply(init_xavier)
model2 = nn.Linear(10,2).cuda()
network_ema(model, model2, 0)
model.apply(init_xavier)

model.weight.data==model2.weight.data

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False]],
       device='cuda:0')

<h1>Preprocessing</h1>

In [11]:
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import os, glob
from nosaveddata import *


import torchvision
from torchvision import transforms

paths = glob.glob('C:/Users/Augusto/Python/PyTorch/RL/mc_data/4/2023_01_09_14_48_09_100636/*.jpg')
path = 'C:/Users/Augusto/Python/PyTorch/RL/mc_data/4/2023_01_09_14_48_09_100636/7,0,0,0,0,0,0,0,0,0,0,0,0,3,0,.jpg'



tfms = transforms.Compose([
                           transforms.Resize((96, 72)),
                           transforms.ToTensor()
                        ])

img = Image.open(path)
imgs=[]
for p in paths:
    imgs.append(tfms(Image.open(p)))
imgs=torch.stack(imgs)

print(imgs.shape)



imgs, augments_applied = preprocess_iwm_no_solarize(imgs)
    


#plt.imshow(img_tfms)
plot_imgs(imgs.permute(0,2,3,1))
augments_applied

FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/Augusto/Python/PyTorch/RL/mc_data/4/2023_01_09_14_48_09_100636/7,0,0,0,0,0,0,0,0,0,0,0,0,3,0,.jpg'

In [8]:
import torch
from torch import nn
import torch.nn.functional as F

from nosaveddata import *



def gray_scale_stacked(X, p=0.2, stacks=4):
    # Input: Tensor T e (B,C,T,D)
    
    probs = get_img_preprocessing_prob(X.shape[0], p, X.device)
    stacked_probs = probs.repeat_interleave(stacks,0)
    X = X.view(-1,X.shape[1]//stacks,*X.shape[-2:])
    
    gray_img = X.mean(1,keepdim=True).expand(-1,3,-1,-1)
    
    X = (1-stacked_probs)*X + stacked_probs*gray_img
    
    return X.view(X.shape[0]//stacks, -1, *X.shape[-2:]), probs.squeeze()

def gaussian_blur(X, p=0.2, stacks=4, sigma_min=0.1, sigma_max=2):
    # Input: Tensor T e (B,C,T,D)
    
    probs = get_img_preprocessing_prob(X.shape[0], p, X.device)
    tfms = transforms.GaussianBlur(3, (sigma_min, sigma_max))
    
    blurred = tfms(X)
    X = (1-probs)*X + probs*blurred
    
    return X, probs.squeeze()

def solarization_stacked(X, p=0.2, stacks=4):
    # Input: Tensor T e (B,C,T,D)

    probs = get_img_preprocessing_prob(X.shape[0], p, X.device)
    stacked_probs = probs.repeat_interleave(stacks,0)
    
    X = X.view(-1,X.shape[1]//stacks,*X.shape[-2:])
    
    tfms = transforms.RandomSolarize(0,p=1) # This prob is applied over all the batch or no image at all
    
    solarized = tfms(X)
    X = (1-stacked_probs)*X + stacked_probs*solarized
    
    return X.view(X.shape[0]//stacks, -1, *X.shape[-2:]), probs.squeeze()


def preprocess_iwm_stacked(imgs, p=0.2, stacks=4):
    # Applies the same preprocessing for all images in the sequence, but separated by each beach
    augments_applied=[]
    
    imgs, augmented = gray_scale_stacked(imgs, p, stacks)
    augments_applied.append(augmented)
    
    imgs, augmented = gaussian_blur_stacked(imgs, p, stacks)
    augments_applied.append(augmented)
    
    imgs, augmented = solarization_stacked(imgs, p, stacks)
    augments_applied.append(augmented)
    
    augments_applied = torch.stack(augments_applied,1)
    return imgs, augments_applied

preprocess_iwm_stacked(torch.randn(32,12,96,72, device='cuda'))[0].shape

torch.Size([32, 12, 96, 72])

In [None]:
plot_img(imgs[-1].permute(1,2,0))

<h1>DiT</h1>

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from nosaveddata import *

unet = UNet_DiT_S_4(in_channels=4).cuda()
x=torch.randn(32,4,32,32).cuda()
c=torch.randn(32,384).cuda()
t=torch.randint(0,1000,(32,)).cuda()
unet(x,t).shape

GPT Transformer Parameters: 31.91M


torch.Size([32, 4, 32, 32])

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from nosaveddata import *


model = DiT_Transformer(128, 8, 8, 108).cuda()

X = torch.randn(16,108,128).cuda()
c = torch.randn(16,128).cuda()

model(X,c).shape

DiT Transformer Parameters: 2.38M


torch.Size([16, 108, 128])

In [None]:
model = DiT_Transformer(512, 8, 8, 128).cuda()

X = torch.randn(16,128,512).cuda()
c = torch.randn(16,512).cuda()

model(X,c).shape