In [None]:
# Sample list; the next cell unpacks it into a new list.
a = [1, 2, 3, 4]

In [None]:
# New list made of the elements of `a` followed by 5 (iterable unpacking).
b = [*a, 5]

In [None]:
# Bare expression: in a notebook cell this shows the value of `b`.
b

In [None]:
# Length of a freshly created empty list (evaluates to 0).
len(list())

In [None]:
# Scratch arithmetic; the sum evaluates to 9970.
7470+830+1670

In [None]:
import torch

In [None]:
# Load the checkpoint onto the CPU so it can be inspected on any machine,
# whatever device its tensors were saved from.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
w = torch.load('bin/teachers/best_vit_base16_photo.pth', map_location='cpu')

In [None]:
# Show the keys of the loaded checkpoint object.
w.keys()

In [None]:
# Re-key the checkpoint: every key loses its first 7 characters, values are
# kept as-is.
# NOTE(review): presumably this strips a 7-character prefix such as "module."
# from each parameter name -- confirm against the output of w.keys().
w_new = {key[7:]: w[key] for key in w}

In [None]:
# Show the keys after the first 7 characters were dropped from each one.
w_new.keys()

In [None]:
import torch
import random
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MixStyle(nn.Module):
    """MixStyle: mix per-sample feature statistics inside a batch.

    During training, and with probability ``p``, every sample is normalised
    with its own mean / standard deviation (taken over dims 2 and 3) and then
    re-scaled with a convex combination of its own statistics and those of a
    randomly picked sample of the same batch.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        """
        Args:
          p (float): probability of applying the mixing in a forward pass.
          alpha (float): both concentration parameters of the Beta
            distribution the mixing weights are drawn from.
          eps (float): added to the variance before the square root to avoid
            numerical issues.
        """
        super().__init__()
        self._activated = True
        self.p = p
        self.alpha = alpha
        self.eps = eps
        self.beta = torch.distributions.Beta(alpha, alpha)

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        """Switch the layer on (``True``) or off (``False``)."""
        self._activated = status

    def forward(self, x):
        """Return ``x`` with mixed statistics, or ``x`` itself when skipped.

        The input is returned untouched when the module is not in training
        mode, when it has been deactivated, or when the random draw exceeds
        ``p``.
        """
        enabled = self.training and self._activated
        if not enabled or random.random() > self.p:
            return x

        batch_size = x.size(0)
        stat_dims = [2, 3]

        # The statistics are detached: no gradient flows through them.
        mean = x.mean(dim=stat_dims, keepdim=True).detach()
        std = (x.var(dim=stat_dims, keepdim=True) + self.eps).sqrt().detach()
        normalised = (x - mean) / std

        # One mixing weight per sample, broadcast over the remaining dims.
        weight = self.beta.sample((batch_size, 1, 1, 1)).to(x.device)

        # Random partner for every sample of the batch.
        partner = torch.randperm(batch_size)
        mixed_mean = mean * weight + mean[partner] * (1 - weight)
        mixed_std = std * weight + std[partner] * (1 - weight)

        return normalised * mixed_std + mixed_mean


class MixStyle2(nn.Module):
    """MixStyle (w/ domain prior).

    The input should contain two equal-sized mini-batches from two distinct
    domains stacked along dim 0: the first half comes from one domain, the
    second half from the other. Each sample's statistics are mixed with those
    of a randomly chosen sample from the opposite half.

    Statistics are taken over dim 2 and the mixing weights have shape
    (B, 1, 1), so this variant is written for 3-D inputs.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        self._activated = True

    def __repr__(self):
        # Report the actual class name (the string used to say "MixStyle").
        return f'MixStyle2(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        """Switch the layer on (``True``) or off (``False``)."""
        self._activated = status

    def forward(self, x):
        """
        For the input x, the first half comes from one domain,
        while the second half comes from the other domain.

        Returns x unchanged when the module is not training, is deactivated,
        or the random draw exceeds ``p``.

        Raises:
          ValueError: if mixing is applied to a batch with an odd number of
            samples, which cannot be split into two equal-sized halves.
        """
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)
        if B % 2 != 0:
            # With an odd batch the two shuffled halves below would cover only
            # B - 1 indices and the mixing would fail with an obscure
            # shape/unpacking error; fail early with a clear message instead.
            raise ValueError(
                'MixStyle2 expects an even batch size (two equal-sized '
                f'domain mini-batches), got {B}'
            )

        mu = x.mean(dim=[2], keepdim=True)
        var = x.var(dim=[2], keepdim=True)
        sig = (var + self.eps).sqrt()
        # Statistics are detached: no gradient flows through them.
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        # Reversed indices: the first chunk holds second-half indices and the
        # second chunk holds first-half indices. Shuffling each chunk and
        # concatenating them pairs every sample with one from the other half.
        perm = torch.arange(B - 1, -1, -1)
        perm_b, perm_a = perm.chunk(2)
        perm_b = perm_b[torch.randperm(B // 2)]
        perm_a = perm_a[torch.randperm(B // 2)]
        perm = torch.cat([perm_b, perm_a], 0)

        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        return x_normed*sig_mix + mu_mix
    
    
class Intra_ADR(nn.Module):
    """Transposed-convolution branch with a top-k pooled softmax map.

    ``forward`` runs the input through ``E_space`` (ConvTranspose1d ->
    InstanceNorm1d -> ReLU), applies a softmax over dim 2 of the result and
    averages the k largest values along dim 1.

    Args:
      inp (int): number of input channels of the transposed convolution.
      outp (int): number of output channels.
      Norm: accepted but not used.
      group (int): accepted but not used; the convolution uses ``groups=1``.
      stride (int): stride of the transposed convolution.
      **kwargs: accepted and ignored.
    """

    def __init__(self, inp, outp, Norm=None, group=1, stride=1, **kwargs):
        super(Intra_ADR, self).__init__()
        self.E_space = nn.Sequential(
            nn.ConvTranspose1d(inp, outp, kernel_size=2, stride=stride, padding=0, output_padding=0, groups=1,
                               bias=True, dilation=1, padding_mode='zeros'),
            nn.InstanceNorm1d(outp),
            nn.ReLU(inplace=True),
            )
        # Not called in forward(); kept so the attribute remains available.
        self.mixstyle = MixStyle(p=.5, alpha=.3)

    def cc_kth_p(self, input, kth=0):
        """Average the ``kth`` largest values along dim 1.

        Args:
          input (Tensor): tensor with at least ``kth`` entries along dim 1.
          kth (int): how many of the largest values to average. Non-positive
            values (including the default 0) select the built-in default
            of 10.

        Returns:
          Tensor: same shape as ``input`` except that dim 1 has size 1.
        """
        if kth <= 0:
            kth = 10
        top_values = torch.topk(input, kth, dim=1)[0]
        return top_values.mean(1, keepdim=True)

    def forward(self, x):
        """Return ``(pooled, features, features)``.

        ``features`` is the output of ``E_space`` (the same tensor is returned
        twice) and ``pooled`` is the top-k average of its softmax over dim 2.
        """
        branch = self.E_space(x)
        attention = F.softmax(branch, 2)
        branch_out = self.cc_kth_p(attention)
        return branch_out, branch, branch

In [None]:
# Instantiate MixStyle2 with its default arguments.
m = MixStyle2()

In [None]:
# Pass a random (4, 197, 768) tensor through the layer and show the output shape.
m(torch.rand((4,197,768))).shape

In [None]:
# Scratch arithmetic; the product evaluates to 3072.
768*4

In [None]:
# Intra_ADR with 768 input and 768 output channels (rebinds the name `a`).
a = Intra_ADR(768,768)

In [None]:
# Forward a random (32, 768, 197) tensor; `o` holds the tuple returned by forward().
o = a(torch.rand((32, 768, 197)))

In [None]:
# Shapes of the first two tensors of the returned tuple.
o[0].shape, o[1].shape

In [None]:
from models.vision_transformer import vit_base_patch16_224
import timm

In [None]:
# Instantiate vit_base_patch16_224 (pretrained=True, num_classes=10) and move it to the GPU.
# NOTE(review): .cuda() requires a CUDA-capable device to be available.
m = vit_base_patch16_224(pretrained=True, num_classes=10).cuda()

In [None]:
# Forward a random batch without tracking gradients. The input is created on
# the same device as the model's parameters: feeding a CPU tensor to a model
# that was moved to the GPU raises a device-mismatch error.
with torch.no_grad():
    o = m(torch.rand((4, 3, 224, 224), device=next(m.parameters()).device))

In [None]:
# Show the shape of the model output.
o.shape

In [None]:
import tensorflow as tf

In [None]:
def se_module(x, filters, ratio=2):
    """Gate ``x`` with a learned, globally pooled sigmoid mask.

    The input is globally average-pooled, passed through a ReLU Dense layer
    with ``filters // ratio`` units and a sigmoid Dense layer with ``filters``
    units, reshaped to ``(1, filters)`` and multiplied with ``x``.

    Args:
      x: Keras tensor accepted by ``GlobalAveragePooling2D``.
      filters (int): units of the sigmoid layer; presumably the channel
        count of ``x`` -- TODO confirm against callers.
      ratio (int): reduction factor of the intermediate Dense layer.

    Returns:
      The output tensor of the ``Multiply`` layer.
    """
    layers = tf.keras.layers

    squeezed = layers.GlobalAveragePooling2D()(x)
    bottleneck = layers.Dense(filters // ratio, activation='relu')(squeezed)

    gate = layers.Dense(filters, activation='sigmoid')(bottleneck)
    gate = layers.Reshape((1, filters))(gate)

    return layers.Multiply()([x, gate])

def red_module(x, filters, kernel_size):
    """Strided reduction block: two parallel convolutions of ``x``, summed.

    Both convolutions use strides ``(2, 1)``, ``'same'`` padding and ReLU;
    one uses ``kernel_size``, the other a kernel size of 1.

    Args:
      x: input Keras tensor.
      filters (int): output channels of both convolutions.
      kernel_size: kernel size of the main convolution.

    Returns:
      The output tensor of the ``Add`` layer.
    """
    layers = tf.keras.layers
    strides = (2, 1)

    main_path = layers.Conv2D(filters, kernel_size, strides=strides,
                              padding='same', activation='relu')(x)
    shortcut = layers.Conv2D(filters, 1, strides=strides,
                             padding='same', activation='relu')(x)

    return layers.Add()([shortcut, main_path])

def resa_red_module(x, filters, kernel_size):
    """Residual conv + ``se_module`` block followed by ``red_module``.

    ``x`` is convolved (stride 1, ``'same'`` padding, ReLU), gated by
    ``se_module`` with ``ratio=4``, added back to ``x`` and the sum is passed
    to ``red_module``.

    Args:
      x: input Keras tensor.
      filters (int): channel count used by every layer of the block.
      kernel_size: kernel size of the convolution and of ``red_module``.

    Returns:
      The tensor returned by ``red_module``.
    """
    shortcut = x

    features = tf.keras.layers.Conv2D(
        filters, kernel_size, strides=1, padding='same', activation='relu'
    )(x)
    features = se_module(features, filters, ratio=4)
    merged = tf.keras.layers.Add()([features, shortcut])

    return red_module(merged, filters, kernel_size)

def resa_module(x, filters, kernel_size):
    """Residual block: one convolution of ``x`` added back to ``x``.

    Args:
      x: input Keras tensor.
      filters (int): output channels of the convolution.
      kernel_size: kernel size of the convolution.

    Returns:
      The output tensor of the ``Add`` layer.
    """
    convolved = tf.keras.layers.Conv2D(
        filters, kernel_size, strides=1, padding='same', activation='relu'
    )(x)
    return tf.keras.layers.Add()([convolved, x])


def build_model(input_shape, filters=64, kernel_size=(3, 1), n_modules=2):
    """Assemble the Keras model: stem convolution, stacked blocks, linear head.

    Args:
      input_shape: shape passed to the ``Input`` layer (without batch dim).
      filters (int): channel count of the stem and of every block.
      kernel_size: kernel size handed to each ``resa_red_module``; the stem
        convolution always uses ``(3, 1)``.
      n_modules (int): number of ``resa_red_module`` blocks stacked.

    Returns:
      A ``tf.keras.Model`` mapping the input to a 3-unit linear output.
    """
    layers = tf.keras.layers
    inputs = layers.Input(shape=input_shape)

    # Stem: first compression of the input.
    features = layers.Conv2D(filters, (3, 1), strides=1, padding='same',
                             activation='relu')(inputs)

    # Main body.
    for _ in range(n_modules):
        features = resa_red_module(features, filters=filters,
                                   kernel_size=kernel_size)

    # Prediction head (linear, 3 outputs).
    flat = layers.Flatten()(features)
    flat = layers.Dropout(0.2)(flat)
    outputs = layers.Dense(3, activation='linear')(flat)

    # A second head (Flatten -> Dropout(0.2) -> Dense(1, sigmoid)) is
    # currently disabled.

    return tf.keras.Model(inputs, outputs)

In [None]:
# Input shape of the network, laid out as (T, F, 1).
INPUT_SIZE = (10, 8, 1)

# Build the network, configure it for mean-absolute-error training with Adam
# and print the layer summary.
model = build_model(INPUT_SIZE)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=7e-5),
    loss='mean_absolute_error',
)
model.summary()
