In [1]:
import torch
import torchvision
from torchvision import transforms

import torch.nn as nn

from einops import rearrange, reduce, repeat

import timm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!echo $CUDA_VISIBLE_DEVICES

5


In [3]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mshiman[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
wandb.init('Token Fusion', name='my layer + mixing + vit')

In [5]:
from modules.transformer import PatchEmbedding, Transformer
from reformer_pytorch import Reformer
from modules.fnet import FNet

In [6]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_dataset = torchvision.datasets.ImageFolder('/data/ILSVRC/Data/CLS-LOC/train', transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

val_dataset = torchvision.datasets.ImageFolder('/data/ILSVRC/Data/CLS-LOC/val', transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

In [7]:
class MixingClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool1 = nn.AdaptiveAvgPool1d(1)
        self.pool2 = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.LazyLinear(100)
    def forward(self, x):
        x1 = self.pool2(x)
        x = torch.transpose(x, 1, 2)
        x2 = self.pool1(x)
        x1 = torch.squeeze(x1, dim=-1)
        x2 = torch.squeeze(x2, dim=-1)
        x = torch.concat([x1, x2], axis=1)
        x = self.classifier(x)
        return x

In [8]:
class LayerTokenFusion(nn.Module):
    def __init__(self, classifier):
        super().__init__()
        self.resnet = timm.create_model('resnet101', pretrained=True, features_only=True)
        for par in self.resnet.parameters():
            par.requires_grad = False
        in_ch = in_ch = [3, 64] + [2 ** i for i in range(8, 12)]

        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(in_channels=ch, out_channels=3, kernel_size=2**(i), stride=2**(i)) for i, ch in enumerate(in_ch)])
        self.embedders = nn.ModuleList([PatchEmbedding(in_channels=3, out_channels=64, patch_size=16) for i, ch in enumerate(in_ch)])
        self.transformers = nn.ModuleList([Transformer(depth=2, dim=64, mlp_ratio=1) for _ in range(6)])
        self.attend = nn.ModuleList([nn.MultiheadAttention(embed_dim=64, num_heads=1, batch_first=True) for _ in range(6)])
        
        self.classifier  = classifier

    def forward(self, x):
        out = self.resnet(x)
        x = self.embedders[0](x)
        for i, v in enumerate(out):
            v = self.upconvs[i + 1](v)
            v = self.embedders[i + 1](v)

            x = self.transformers[i](x)
            #x = torch.concat([x, v], axis=1)
            z, _ = self.attend[i](x, v, v)
            x = x + z
        x = self.transformers[-1](x)
        x = self.classifier(x)
        return x

x = torch.randn((3, 3, 224, 224)).cuda()
model = LayerTokenFusion(classifier = MixingClassifier()).cuda()



In [7]:
class BridgeBlock(nn.Module):
    def __init__(self, single_weight=False):
        super().__init__()
        in_ch = in_ch = [64] + [2 ** i for i in range(8, 12)]
        self.expand = nn.ModuleList([
            nn.ConvTranspose2d(in_channels=in_ch[i], out_channels=3, kernel_size=2**(i+1),
                               stride=2**(i+1)) for i in range(5)
        ])
    def forward(self, x, feature, n_layer):
        out = self.expand[n_layer](feature)
        x = torch.concat([x, out], axis=1)
        return x

class EarlyTokenFusion(nn.Module):
    def __init__(self, classifier):
        super().__init__()
        self.resnet = timm.create_model('resnet101', features_only=True, pretrained=True)
        for par in self.resnet.parameters():
            par.requires_grad = False
        self.bridge = BridgeBlock()
        self.patch_ebmedding = PatchEmbedding(in_channels=18, out_channels=256, patch_size=16)
        self.transformer = Transformer(depth=12, dim=256, mlp_ratio=1)
        self.classifier = classifier

    def forward(self, x):
        out = self.resnet(x)
        for i, v in enumerate(out):
            x = self.bridge(x, v, i)
        x = self.patch_ebmedding(x)
        x = self.transformer(x)
        x = self.classifier(x)
        return x
    
x = torch.randn((3, 3, 224, 224)).cuda()
model = EarlyTokenFusion(classifier=MixingClassifier()).cuda()



In [12]:
class LateTokenFusion(nn.Module):
    def __init__(self, classifier):
        super().__init__()
        self.classifier = classifier
        self.resnet = timm.create_model('resnet101', features_only=True, pretrained=True)
        for par in self.resnet.parameters():
            par.requires_grad = False

        self.embedder1 = PatchEmbedding(in_channels=3, out_channels=256, patch_size=16)
        self.transformer1 = Transformer(depth=8, dim=256)

        self.embedder2 = PatchEmbedding(in_channels=1024, out_channels=1024, patch_size=1, same=True)
        self.transformer2 = Transformer(depth=4, dim=1024)

        self.UpConv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(1, 4), stride=(1, 4))
    
    def forward(self, x):
        x1 = self.embedder1(x)
        x1 = self.transformer1(x1)
        x1 = torch.unsqueeze(x1, dim=1)
        x1 = self.UpConv(x1)
        x1 = torch.squeeze(x1)

        x2 = self.resnet(x)[-2]
        x2 = self.embedder2(x2)
        x2 = self.transformer2(x2)
        

        x = torch.concat([x1, x2], axis=1)
        #x = torch.transpose(x, 1, 2)
        #print(nn.AvgPool1d(392)(x).size())
        x = self.classifier(x)
        return x
    
class Squeeze(nn.Module):
    def forward(self, x):
        return torch.squeeze(x)
    
x = torch.randn((3, 3, 224, 224)).cuda()
class MixingClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool1 = nn.AdaptiveAvgPool1d(1)
        self.pool2 = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.LazyLinear(100)
    def forward(self, x):
        x1 = self.pool2(x)
        x = torch.transpose(x, 1, 2)
        x2 = self.pool1(x)
        x1 = torch.squeeze(x1)
        x2 = torch.squeeze(x2)
        x = torch.concat([x1, x2], axis=1)
        x = self.classifier(x)
        return x
        
model = LateTokenFusion(classifier = MixingClassifier()).cuda()



In [11]:
model(x).size()

torch.Size([3, 100])

In [9]:
import time
from tqdm import tqdm

def train(model, optimizer, loss_fn, train_dl, val_dl, epochs=100, device='cpu'):

    print('train() called: model=%s, opt=%s(lr=%f), epochs=%d, device=%s\n' % \
          (type(model).__name__, type(optimizer).__name__,
           optimizer.param_groups[0]['lr'], epochs, device))

    history = {} # Collects per-epoch loss and acc like Keras' fit().
    history['loss'] = []
    history['val_loss'] = []
    history['acc'] = []
    history['val_acc'] = []

    start_time_sec = time.time()

    for epoch in range(1, epochs+1):

        # --- TRAIN AND EVALUATE ON TRAINING SET -----------------------------
        model.train()
        train_loss         = 0.0
        num_train_correct  = 0
        num_train_examples = 0

        for batch in tqdm(train_dl):

            optimizer.zero_grad()

            x    = batch[0].to(device)
            y    = batch[1].to(device)
            yhat = model(x)
            loss = loss_fn(yhat, y)

            loss.backward()
            optimizer.step()

            train_loss         += loss.data.item() * x.size(0)
            num_train_correct  += (torch.max(yhat, 1)[1] == y).sum().item()
            num_train_examples += x.shape[0]

        train_acc   = num_train_correct / num_train_examples
        train_loss  = train_loss / len(train_dl.dataset)


        # --- EVALUATE ON VALIDATION SET -------------------------------------
        model.eval()
        val_loss       = 0.0
        num_val_correct  = 0
        num_val_examples = 0

        for batch in tqdm(val_dl):

            x    = batch[0].to(device)
            y    = batch[1].to(device)
            yhat = model(x)
            loss = loss_fn(yhat, y)

            val_loss         += loss.data.item() * x.size(0)
            num_val_correct  += (torch.max(yhat, 1)[1] == y).sum().item()
            num_val_examples += y.shape[0]

        val_acc  = num_val_correct / num_val_examples
        val_loss = val_loss / len(val_dl.dataset)


        if epoch == 1 or epoch % 10 == 0:
          print('Epoch %3d/%3d, train loss: %5.2f, train acc: %5.2f, val loss: %5.2f, val acc: %5.2f' % \
                (epoch, epochs, train_loss, train_acc, val_loss, val_acc))
    
        
        wandb.log({
            'train loss': train_loss,
            'val loss': val_loss,
            'train acc': train_acc,
            'val acc': val_acc
        })
        history['loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        
        #torch.save(model.state_dict(), 'currrr.h5')

    # END OF TRAINING LOOP


    end_time_sec       = time.time()
    total_time_sec     = end_time_sec - start_time_sec
    time_per_epoch_sec = total_time_sec / epochs
    print()
    print('Time total:     %5.2f sec' % (total_time_sec))
    print('Time per epoch: %5.2f sec' % (time_per_epoch_sec))

    return history

In [10]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)

In [None]:
history = train(model, optimizer, criterion, train_loader, val_loader, epochs=24, device='cuda')
#torch.save(model.state_dict(), 'cur2.h5')
#wandb.save('cur2.h5')

train() called: model=LayerTokenFusion, opt=AdamW(lr=0.000100), epochs=24, device=cuda



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1633/1633 [1:05:19<00:00,  2.40s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 580/580 [10:41<00:00,  1.11s/it]


Epoch   1/ 24, train loss:  0.11, train acc:  0.97, val loss:  0.14, val acc:  0.96


 63%|██████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                           | 1031/1633 [37:58<09:24,  1.07it/s]

In [12]:
history

import torch
import torchvision
from torchvision import transforms

import torch.nn as nn

from einops import rearrange, reduce, repeat

import timm
import wandb

wandb.login()
wandb.init('Token Fusion', project='my layer + mixing + vit', resume=True)
from modules.transformer import PatchEmbedding, Transformer
from reformer_pytorch import Reformer
from modules.fnet import FNet
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_dataset = torchvision.datasets.ImageFolder('/data/ILSVRC/Data/CLS-LOC/train', transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

val_dataset = torchvision.datasets.ImageFolder('/data/ILSVRC/Data/CLS-LOC/val', transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
        

In [11]:
model = timm.create_model('resnet101', pretrained=True)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [None]:
model = timm.create_model('resnet101', pretrained=True)

model.eval()
val_loss       = 0.0
num_val_correct  = 0
num_val_examples = 0

for batch in tqdm(val_dl):

    x    = batch[0].to(device)
    y    = batch[1].to(device)
    yhat = model(x)
    loss = loss_fn(yhat, y)

    val_loss         += loss.data.item() * x.size(0)
    num_val_correct  += (torch.max(yhat, 1)[1] == y).sum().item()
    num_val_examples += y.shape[0]

val_acc  = num_val_correct / num_val_examples
val_loss = val_loss / len(val_dl.dataset)


In [11]:
history

{'loss': [0.6135925867340781,
  0.2705157463265378,
  0.2064774487246737,
  0.1559920800612375,
  0.11992709155309428,
  0.09436373843291204,
  0.07545991969960586,
  0.06159499103661947,
  0.052060843378862444,
  0.04702815396277903,
  0.041631026792422976,
  0.03816739026094485,
  0.03589167668153447,
  0.03521862431499343,
  0.030950547608680316,
  0.029794174555466433],
 'val_loss': [0.3060782493714911,
  0.25821791502930513,
  0.21299451223542254,
  0.23188905120210931,
  0.24707616758606393,
  0.24491693581045537,
  0.2516499959808879,
  0.26488420082111447,
  0.24849710881547546,
  0.27997587498539206,
  0.27884319014894055,
  0.2872113626650069,
  0.3004888705318007,
  0.2951320726693032,
  0.2981743334061938,
  0.3143005586222774],
 'acc': [0.8336937799043063,
  0.9203253588516747,
  0.9378564593301435,
  0.9520765550239234,
  0.9620478468899522,
  0.9699521531100479,
  0.975511961722488,
  0.9804784688995215,
  0.9832631578947368,
  0.9850334928229665,
  0.9873397129186603,
 

In [12]:
import os

classes = os.listdir('/data/VasilievDV/ILSVRC/Data/CLS-LOC/train')[:100]

In [18]:
import shutil

for cls in classes:
    os.makedirs('/data/nnshiman/ILSVRC/Data/CLS-LOC/train/' + cls)
    os.makedirs('/data/nnshiman/ILSVRC/Data/CLS-LOC/val/' + cls)
    files = os.listdir('/data/VasilievDV/ILSVRC/Data/CLS-LOC/train/' + cls)
    tr_files = files[:1100]
    val_files = files[1100:]
    for file in tr_files:
        shutil.copyfile('/data/VasilievDV/ILSVRC/Data/CLS-LOC/train/' + cls + '/' + file,
                       '/data/nnshiman/ILSVRC/Data/CLS-LOC/train/' + cls + '/' + file)
    for file in val_files:
        shutil.copyfile('/data/VasilievDV/ILSVRC/Data/CLS-LOC/train/' + cls + '/' + file,
                       '/data/nnshiman/ILSVRC/Data/CLS-LOC/val/' + cls + '/' + file)

In [17]:
!rm -rf /data/nnshiman/