In [None]:
# Mount Google Drive at /content/drive; later cells read scripts, configs and checkpoints from there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#Runtime -> Change runtime type-> GPU
# Show the GPU, driver and CUDA version attached to this runtime.
! nvidia-smi

Sat Jan  8 14:52:46 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P8    30W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# code
# Clone the BYOL-ViT and BYOL repositories into the current working directory.
# NOTE(review): later cells run files from /content/drive/MyDrive/BYOL-ViT-Hourglass/... —
# confirm these clones actually end up under that folder.
!git clone https://github.com/SafwenNaimi/BYOL-ViT.git
!git clone https://github.com/SafwenNaimi/BYOL.git

Cloning into 'BYOL-ViT'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (69/69), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 69 (delta 37), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (69/69), done.
Cloning into 'BYOL'...
remote: Enumerating objects: 85, done.[K
remote: Counting objects: 100% (85/85), done.[K
remote: Compressing objects: 100% (84/84), done.[K
remote: Total 85 (delta 26), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (85/85), done.


### Dataset downloading and preprocessing

In [None]:
# STL dataset download
# NOTE(review): the clone lands in the current working directory, while the next cell runs
# stl10_input.py from the Drive project folder — confirm both refer to the same checkout.
!git clone https://github.com/mttk/STL10.git # run stl10_input.py

In [None]:
# Run the STL10 helper script (stl10_input.py) from the Drive project folder.
!python /content/drive/MyDrive/BYOL-ViT-Hourglass/STL10/stl10_input.py

In [None]:
# Run the project's data_preprocessing.py script from the Drive project folder.
!python /content/drive/MyDrive/BYOL-ViT-Hourglass/data_preprocessing.py

### BYOL training

In [None]:
# train BYOL
# Launches the BYOL training script; its per-step loss is streamed to the cell output.
!python /content/drive/MyDrive/BYOL-ViT-Hourglass/BYOL/Train_BYOL.py

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Step [2/275]:	Loss: 1.5143251419067383
Step [3/275]:	Loss: 2.1860415935516357
Step [4/275]:	Loss: 1.4601454734802246
Step [5/275]:	Loss: 1.5380306243896484
Step [6/275]:	Loss: 1.5735058784484863
Step [7/275]:	Loss: 1.160517930984497
Step [8/275]:	Loss: 1.0158894062042236
Step [9/275]:	Loss: 1.897087812423706
Step [10/275]:	Loss: 0.776131272315979
Step [11/275]:	Loss: 2.1397247314453125
Step [12/275]:	Loss: 1.9997661113739014
Step [13/275]:	Loss: 2.0151329040527344
Step [14/275]:	Loss: 1.8773854970932007
Step [15/275]:	Loss: 1.2764887809753418
Step [16/275]:	Loss: 1.0323164463043213
Step [17/275]:	Loss: 1.667266845703125
Step [18/275]:	Loss: 1.629448652267456
Step [19/275]:	Loss: 1.3195356130599976
Step [20/275]:	Loss: 1.538228988647461
Step [21/275]:	Loss: 1.7641196250915527
Step [22/275]:	Loss: 1.3963326215744019
Step [23/275]:	Loss: 1.4958974123001099
Step [24/275]:	Loss: 1.3914896249771118
Step [25/275]:	Loss: 1.569973

### Hourglass definition (modified)




In [None]:
!pip install einops

Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2


In [None]:
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce, repeat

# helpers

def exists(val):
    """Return True when `val` is anything other than None (falsy values such as 0 or '' count as existing)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val

def pad_to_multiple(tensor, multiple, dim = -1, value = 0):
    """Right-pad `tensor` along `dim` until its length there is a multiple of `multiple`.

    Args:
        tensor: tensor to pad.
        multiple: positive integer the length along `dim` must become divisible by.
        dim: axis to pad; negative or non-negative indices are both accepted.
        value: fill value for the padded positions.

    Returns:
        `tensor` itself when no padding is needed, otherwise a new padded tensor.
    """
    # F.pad addresses axes from the last one backwards, so work with a negative index.
    if dim >= 0:
        dim = dim - tensor.ndim

    seq_len = tensor.shape[dim]
    # Number of positions missing to reach the next multiple (0 when already aligned).
    remainder = (-seq_len) % multiple
    if remainder == 0:
        return tensor

    # One (left, right) = (0, 0) pair for every axis that comes after `dim`.
    pad_offset = (0,) * (-1 - dim) * 2
    # Call through torch.nn.functional explicitly so the result does not depend on which
    # module the notebook-level alias `F` currently points to.
    return torch.nn.functional.pad(tensor, (*pad_offset, 0, remainder), value = value)

def cast_tuple(val, depth = 1):
    """Return `val` unchanged if it is already a tuple, else a tuple holding `val` repeated `depth` times."""
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(depth))

# factory

def get_hourglass_transformer(
    dim,
    *,
    depth,
    shorten_factor,
    attn_resampling,
    updown_sample_type,
    **kwargs
):
    """Build a plain Transformer (integer `depth`) or an HourglassTransformer (3-tuple `depth`).

    Extra keyword arguments are forwarded to whichever class is constructed. An integer
    `depth` must not be combined with a truthy `shorten_factor`.
    """
    single_stack = isinstance(depth, int)
    nested_spec = isinstance(depth, tuple) and len(depth) == 3

    assert single_stack or nested_spec, 'depth must be either an integer or a tuple of 3, indicating (pre_transformer_depth, <nested-hour-glass-config>, post_transformer_depth)'
    assert not (single_stack and shorten_factor), 'there does not need to be a shortening factor when only a single transformer block is indicated (depth of one integer value)'

    if not single_stack:
        return HourglassTransformer(
            dim = dim,
            depth = depth,
            shorten_factor = shorten_factor,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            **kwargs
        )

    return Transformer(dim = dim, depth = depth, **kwargs)

# up and down sample classes

class NaiveDownsample(nn.Module):
    """Shorten a sequence by averaging each run of `shorten_factor` consecutive positions."""

    def __init__(self, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor

    def forward(self, x):
        # (b, n * s, d) -> (b, n, d): mean over every group of s neighbouring positions.
        group_size = self.shorten_factor
        pooled = reduce(x, 'b (n s) d -> b n d', 'mean', s = group_size)
        return pooled

class NaiveUpsample(nn.Module):
    """Lengthen a sequence by repeating every position `shorten_factor` times."""

    def __init__(self, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor

    def forward(self, x):
        # (b, n, d) -> (b, n * s, d): each position is copied s times, copies kept adjacent.
        copies = self.shorten_factor
        expanded = repeat(x, 'b n d -> b (n s) d', s = copies)
        return expanded

class LinearDownsample(nn.Module):
    """Shorten a sequence by concatenating each group of `shorten_factor` positions and projecting back to `dim`."""

    def __init__(self, dim, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor
        # Maps the s concatenated feature vectors of one group onto a single dim-sized vector.
        self.proj = nn.Linear(dim * shorten_factor, dim)

    def forward(self, x):
        # (b, n * s, d) -> (b, n, s * d), then project the last axis to d.
        grouped = rearrange(x, 'b (n s) d -> b n (s d)', s = self.shorten_factor)
        return self.proj(grouped)

class LinearUpsample(nn.Module):
    """Lengthen a sequence by projecting each position to `shorten_factor * dim` features and splitting them into positions."""

    def __init__(self, dim, shorten_factor):
        super().__init__()
        self.shorten_factor = shorten_factor
        # Produces s dim-sized vectors (stored side by side) for every input position.
        self.proj = nn.Linear(dim, dim * shorten_factor)

    def forward(self, x):
        # (b, n, d) -> (b, n, s * d) via the projection, then unfold to (b, n * s, d).
        widened = self.proj(x)
        return rearrange(widened, 'b n (s d) -> b (n s) d', s = self.shorten_factor)

# classes

class PreNormResidual(nn.Module):
    """Wrap `fn` as `fn(LayerNorm(x)) + x` (pre-normalisation with a residual connection)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        # Only `x` is normalised; keyword arguments reach `fn` untouched.
        normed = self.norm(x)
        branch = self.fn(normed, **kwargs)
        return branch + x

class Attention(nn.Module):
    """Multi-head scaled dot-product attention with optional context, key mask and causal mask.

    Args:
        dim: feature size of the inputs and of the output.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied to the attention weights.
        causal: when True, each query may only attend to keys at or before its own position.
    """

    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False
    ):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.causal = causal
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # Keys and values come out of one projection and are split in forward().
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, mask = None):
        """Attend from `x` to `context` (or to `x` itself when `context` is None).

        `mask`, when given, is a boolean tensor over the key positions; positions where it is
        False are excluded from attention.
        """
        heads = self.heads
        source = default(context, x)

        queries = self.to_q(x)
        keys, values = self.to_kv(source).chunk(2, dim = -1)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = heads)

        queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)
        queries = queries * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', queries, keys)
        # Most negative finite value of the dtype: masked logits get ~zero weight after softmax.
        fill_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            key_mask = rearrange(mask, 'b j -> b () () j')
            sim = sim.masked_fill(~key_mask, fill_value)

        if self.causal:
            rows, cols = sim.shape[-2:]
            # True strictly above the (cols - rows)-shifted diagonal, i.e. at future key positions.
            future = torch.ones(rows, cols, device = x.device, dtype = torch.bool).triu_(cols - rows + 1)
            sim = sim.masked_fill(rearrange(future, 'i j -> () () i j'), fill_value)

        attn = self.dropout(sim.softmax(dim = -1))

        out = einsum('b h i j, b h j d -> b h i d', attn, values)
        out = rearrange(out, 'b h n d -> b n (h d)', h = heads)
        return self.to_out(out)

def FeedForward(dim, mult = 4, dropout = 0.):
    """Build the position-wise MLP: Linear(dim -> dim*mult), GELU, Dropout, Linear(dim*mult -> dim)."""
    hidden_dim = dim * mult
    stages = [
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
    ]
    return nn.Sequential(*stages)

# transformer classes

class Transformer(nn.Module):
    """A stack of `depth` (attention, feed-forward) pairs, each wrapped in PreNormResidual.

    Args:
        dim: feature size.
        depth: number of (attention, feed-forward) pairs.
        causal, heads, dim_head, attn_dropout: forwarded to every Attention layer.
        ff_mult, ff_dropout: forwarded to every FeedForward block.
        norm_out: when True a final LayerNorm is applied, otherwise the output is returned as is.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        causal = False,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        norm_out = False
    ):
        super().__init__()

        blocks = []
        for _ in range(depth):
            attn_block = PreNormResidual(
                dim,
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout, causal = causal)
            )
            ff_block = PreNormResidual(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
            blocks.append(nn.ModuleList([attn_block, ff_block]))

        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity()

    def forward(self, x, context = None, mask = None):
        """Run `x` through every layer; `context` and `mask` are passed to each attention block."""
        for attn_block, ff_block in self.layers:
            x = ff_block(attn_block(x, context = context, mask = mask))

        return self.norm(x)

class HourglassTransformer(nn.Module):
    """Hourglass transformer (modified) that ends in a sigmoid classification head.

    The sequence is processed by `pre_transformer`, shortened by `shorten_factor`, processed
    by the valley transformer, brought back to its original length, processed by
    `post_transformer`, and finally flattened and mapped to `num_classes` scores.

    Args:
        dim: feature size of every token.
        depth: 3-tuple `(pre_layers_depth, valley_depth, post_layers_depth)`; `valley_depth`
            is either an int (plain Transformer) or another 3-tuple (nested hourglass).
        shorten_factor: int, or a tuple/list whose first entry is used at this level and
            whose remaining entries are handed to the valley.
        attn_resampling: add one attention layer right after downsampling and one right
            after upsampling.
        updown_sample_type: 'naive' (average pool / repeat) or 'linear' (learned projection).
        heads, dim_head: attention configuration shared by all inner transformers.
        causal: use causal attention and shift the sequence before pooling.
        norm_out: apply a LayerNorm before the classification head.
        num_classes: number of outputs of the classification head.
        head_in_features: input size of the classification head. forward() flattens the
            whole output (batch axis included), so this must equal batch * seq_len * dim;
            the default matches the single (1, 512, 12*12) samples fed by the training cells.

    NOTE: the head is registered as `self.to_class`, so it is returned by `parameters()`,
    follows `.to(device)`, and is part of `state_dict()`; state dicts saved without these
    keys need `strict=False` to load.
    NOTE: a nested hourglass valley (tuple `valley_depth`) would also end in this head and
    return class scores instead of a sequence, so only an integer `valley_depth` is usable.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        shorten_factor = 2,
        attn_resampling = True,
        updown_sample_type = 'naive',
        heads = 8,
        dim_head = 64,
        causal = False,
        norm_out = False,
        num_classes = 10,
        head_in_features = 512 * 144
    ):
        super().__init__()
        assert len(depth) == 3, 'depth should be a tuple of length 3'
        assert updown_sample_type in {'naive', 'linear'}, 'downsample / upsample type must be either naive (average pool and repeat) or linear (linear projection and reshape)'

        pre_layers_depth, valley_depth, post_layers_depth = depth

        # Work out which shortening factor applies here and which one the valley receives.
        if isinstance(shorten_factor, (tuple, list)):
            shorten_factor, *rest_shorten_factor = shorten_factor
        elif isinstance(valley_depth, int):
            shorten_factor, rest_shorten_factor = shorten_factor, None
        else:
            shorten_factor, rest_shorten_factor = shorten_factor, shorten_factor

        transformer_kwargs = dict(
            dim = dim,
            heads = heads,
            dim_head = dim_head
        )

        self.causal = causal
        self.shorten_factor = shorten_factor

        if updown_sample_type == 'naive':
            self.downsample = NaiveDownsample(shorten_factor)
            self.upsample   = NaiveUpsample(shorten_factor)
        elif updown_sample_type == 'linear':
            self.downsample = LinearDownsample(dim, shorten_factor)
            self.upsample   = LinearUpsample(dim, shorten_factor)
        else:
            raise ValueError(f'unknown updown_sample_type keyword value - must be either naive or linear for now')

        self.valley_transformer = get_hourglass_transformer(
            shorten_factor = rest_shorten_factor,
            depth = valley_depth,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            causal = causal,
            **transformer_kwargs
        )

        self.attn_resampling_pre_valley = Transformer(depth = 1, **transformer_kwargs) if attn_resampling else None
        self.attn_resampling_post_valley = Transformer(depth = 1, **transformer_kwargs) if attn_resampling else None

        self.pre_transformer = Transformer(depth = pre_layers_depth, causal = causal, **transformer_kwargs)
        self.post_transformer = Transformer(depth = post_layers_depth, causal = causal, **transformer_kwargs)
        self.norm_out = nn.LayerNorm(dim) if norm_out else nn.Identity()

        # Classification head. It is created once here (not inside forward) so that its
        # weights persist between calls and are visible to the optimizer.
        self.to_class = nn.Linear(head_in_features, num_classes)
        self.s = nn.Sigmoid()

    def forward(self, x, mask = None):
        """Return a 1-D tensor of `num_classes` sigmoid scores for the whole input `x`.

        `x` is (batch, seq_len, dim); `mask`, if given, is a boolean (batch, seq_len) tensor
        that is forwarded to the inner transformers.
        """
        # b : batch, n : sequence length, d : feature dimension, s : shortening factor

        s, b, n = self.shorten_factor, *x.shape[:2]

        # top half of hourglass, pre-transformer layers

        x = self.pre_transformer(x, mask = mask)

        # pad to multiple of shortening factor, in preparation for pooling

        x = pad_to_multiple(x, s, dim = -2)

        if exists(mask):
            padded_mask = pad_to_multiple(mask, s, dim = -1, value = False)

        # save the residual, and for "attention resampling" at downsample and upsample

        x_residual = x.clone()

        # if autoregressive, do the shift by shortening factor minus one

        if self.causal:
            shift = s - 1
            # Call through torch.nn.functional explicitly so the shift does not depend on
            # which module the notebook-level alias `F` currently points to.
            x = torch.nn.functional.pad(x, (0, 0, shift, -shift), value = 0.)

            if exists(mask):
                padded_mask = torch.nn.functional.pad(padded_mask, (shift, -shift), value = False)

        # downsample (average pool or linear projection, depending on updown_sample_type)

        downsampled = self.downsample(x)

        if exists(mask):
            # a pooled position is valid when at least one of its s source positions is
            downsampled_mask = reduce(padded_mask, 'b (n s) -> b n', 'sum', s = s) > 0
        else:
            downsampled_mask = None

        # pre-valley "attention resampling" - they have the pooled token in each bucket attend to the tokens pre-pooled

        if exists(self.attn_resampling_pre_valley):
            if exists(mask):
                attn_resampling_mask = rearrange(padded_mask, 'b (n s) -> (b n) s', s = s)
            else:
                attn_resampling_mask = None

            downsampled = self.attn_resampling_pre_valley(
                rearrange(downsampled, 'b n d -> (b n) () d'),
                rearrange(x, 'b (n s) d -> (b n) s d', s = s),
                mask = attn_resampling_mask
            )

            downsampled = rearrange(downsampled, '(b n) () d -> b n d', b = b)

        # the "valley" - either a regular transformer or another hourglass

        x = self.valley_transformer(downsampled, mask = downsampled_mask)

        valley_out = x.clone()

        # upsample back to the padded length

        x = self.upsample(x)

        # add the residual

        x = x + x_residual

        # post-valley "attention resampling"

        if exists(self.attn_resampling_post_valley):
            x = self.attn_resampling_post_valley(
                rearrange(x, 'b (n s) d -> (b n) s d', s = s),
                rearrange(valley_out, 'b n d -> (b n) () d')
            )

            x = rearrange(x, '(b n) s d -> b (n s) d', b = b)

        # bring sequence back to original length, if it were padded for pooling

        x = x[:, :n]

        # post-valley transformers

        x = self.post_transformer(x, mask = mask)

        x = self.norm_out(x)

        # classification head: flatten everything (batch axis included) into one vector,
        # project it to num_classes scores and squash them with a sigmoid

        x = torch.flatten(x)
        x = self.to_class(x)
        x = self.s(x)

        return x

# main class

class HourglassTransformerLM(nn.Module):
    """Token-level wrapper: token + learned positional embeddings, an hourglass transformer, and a projection to `num_tokens` logits.

    NOTE(review): the HourglassTransformer defined in this notebook ends with a
    flatten + classification head, so what reaches `to_logits` is no longer a
    (batch, seq, dim) tensor — confirm whether this wrapper is still meant to be usable.
    """

    def __init__(
        self,
        *,
        num_tokens,
        dim,
        max_seq_len,
        depth,
        shorten_factor = None,
        heads = 8,
        dim_head = 64,
        attn_resampling = True,
        updown_sample_type = 'naive',
        causal = True
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        # Embeddings for token ids and for positions 0 .. max_seq_len - 1.
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        backbone_config = dict(
            dim = dim,
            depth = depth,
            shorten_factor = shorten_factor,
            attn_resampling = attn_resampling,
            updown_sample_type = updown_sample_type,
            dim_head = dim_head,
            heads = heads,
            causal = causal,
            norm_out = True
        )
        self.transformer = get_hourglass_transformer(**backbone_config)

        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x, mask = None):
        """Embed the token ids in `x`, add positional embeddings, run the transformer and project to logits."""
        embedded = self.token_emb(x)

        # One positional embedding per sequence position, broadcast over the batch axis.
        positions = torch.arange(embedded.shape[-2], device = x.device)
        embedded = embedded + rearrange(self.pos_emb(positions), 'n d -> () n d')

        hidden = self.transformer(embedded, mask = mask)
        return self.to_logits(hidden)

### Features loading

In [None]:
from torch import nn
# `F` must be torch.nn.functional: the hourglass code defined earlier calls F.pad, which
# torch.functional does not provide.
import torch.nn.functional as F
from torch import optim

In [None]:
from torchvision.models import resnet50
import torch
import tqdm
from torch.utils.data import DataLoader, Dataset

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print('No GPU available')
else:
    device = torch.device("cuda")
    print('GPU: ', torch.cuda.get_device_name(0))

GPU:  Tesla K80


In [None]:
#from torch.utils.data import Dataset

class Dataset(torch.utils.data.Dataset):
    """In-memory image dataset described by a CSV file with `Filename` and `Label` columns.

    Every image is opened, converted to RGB and kept in memory when the dataset is built.
    The base class is referenced by its full path (rather than by the bare name `Dataset`)
    so that re-running this cell does not make the class inherit from itself.
    """

    def __init__(self, cfg, annotation_file, data_type='train',
                 transform=None):
        """
        Args:
            cfg: config object providing `data_path`, `imgs_dir`, `labels_dir` and
                `pretext` (plus `num_rot` when `pretext == 'rotation'`).
            annotation_file (string): name of the csv/txt file holding the dataset
                labels, looked up under `cfg.data_path/cfg.labels_dir`.
            data_type: accepted for interface compatibility; not used.
            transform: optional callable applied to each image in `__getitem__`.
        """
        self.data_path = os.path.join(cfg.data_path, cfg.imgs_dir)
        self.label_path = os.path.join(cfg.data_path, cfg.labels_dir, annotation_file)
        self.transform = transform
        self.pretext = cfg.pretext
        if self.pretext == 'rotation':
            self.num_rot = cfg.num_rot
        self._load_data()

    def _load_data(self):
        """Read the label CSV and fill `self.loaded_data` with `(image, label, image_name)` tuples."""
        self.labels = pd.read_csv(self.label_path)

        self.loaded_data = []
        for i in range(self.labels.shape[0]):
            # The `Filename` column is used as the path as is (it is not joined with self.data_path).
            img_name = self.labels['Filename'][i]
            label = self.labels['Label'][i]
            # convert() returns a fully loaded copy, so the underlying file can be closed
            # right away; otherwise too many files stay open on large datasets.
            with Image.open(img_name) as source:
                img = source.convert('RGB')
            self.loaded_data.append((img, label, img_name))

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        """Return `(image, label)` for sample `idx`; indices wrap around the dataset length."""
        idx = idx % len(self.loaded_data)
        img, label, img_name = self.loaded_data[idx]
        img, label = self._read_data(img, label)

        return img, label

    def _read_data(self, img, label):
        """Apply `self.transform` to the image (when one was given) and return it with its label."""
        if self.transform is not None:
            img = self.transform(img)
        return img, label


In [None]:
import yaml

class dotdict(dict):
    """dict whose keys can also be read, written and deleted as attributes.

    Attribute reads go through `dict.get`, so a missing key yields None instead of raising.
    """

    __delattr__ = dict.__delitem__
    __setattr__ = dict.__setitem__
    __getattr__ = dict.get

def load_yaml(config_file,config_type='dict'):
    """Parse the YAML file at `config_file` with `yaml.safe_load`.

    Returns the parsed content as is, or wrapped in `dotdict` (attribute-style access)
    when `config_type == 'object'`.
    """
    with open(config_file) as stream:
        parsed = yaml.safe_load(stream)

    return dotdict(parsed) if config_type == 'object' else parsed

In [None]:
# Feature extraction is pinned to the CPU; this overrides any `device` chosen by an earlier cell.
device = 'cpu'
class ResNetFeatures(nn.Module):
    """ResNet-50 trunk (first six child modules) used as an image feature extractor.

    Args:
        pretrained_path: checkpoint whose weights are loaded into the ResNet-50 before it
            is truncated. Defaults to the BYOL checkpoint stored on the mounted Drive.
    """

    def __init__(self, pretrained_path="/content/drive/MyDrive/BYOL-ViT-Hourglass/BYOL/experiments/res50_cct100.pth"):
        super(ResNetFeatures, self).__init__()

        encoder = resnet50(pretrained=False)
        encoder.fc = nn.Linear(in_features=encoder.fc.in_features,out_features=1000,bias=True)

        # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
        state_dict = torch.load(pretrained_path,map_location=device)

        # strict=False tolerates key mismatches; report them so that a checkpoint which
        # silently matches nothing does not go unnoticed.
        load_result = encoder.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print('ResNetFeatures: checkpoint keys do not fully match the model - '
                  '{} missing, {} unexpected'.format(len(load_result.missing_keys),
                                                     len(load_result.unexpected_keys)))
        encoder.to(device)

        # Keep only the first six child modules of the network.
        self.feature_extractor = torch.nn.Sequential(*list(encoder.children())[:6])

    def forward(self, inp):
        """Return the truncated network's feature map for the image batch `inp`."""
        # NOTE(review): the training cells reshape each sample's output to (512, 12*12),
        # i.e. 512 channels on a 12x12 grid - confirm this matches cfg.img_sz.
        out = self.feature_extractor(inp)
        return out


In [None]:
import os
import pandas as pd
from PIL import Image

# Location of the supervised-learning config on the mounted Drive.
config_path = r'/content/drive/MyDrive/BYOL-ViT-Hourglass/BYOL-ViT/config_sl.yaml'
# config_type='object' wraps the parsed YAML in dotdict so fields can be read as cfg.<name>.
cfg = load_yaml(config_path,config_type='object') 
# Dataset joins this name onto cfg.data_path / cfg.labels_dir to locate the labels CSV.
annotation_file = 'stl.csv'

import torchvision.transforms as transforms
# Resize every image to a cfg.img_sz x cfg.img_sz square, then convert it to a tensor.
transform_2 = transforms.Compose([               
                transforms.Resize((cfg.img_sz,cfg.img_sz)),
                transforms.ToTensor(),                
            ])


# Dataset.__init__ calls _load_data(), which opens every image listed in the CSV up front.
train_dataset = Dataset(cfg,annotation_file,
                            data_type='train',transform=transform_2)
print('train data load success')

train data load success


In [None]:
from torch.utils.data.dataloader import default_collate
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
collate_func=default_collate

# Hold out 15% of the samples for validation and train on the rest.
# NOTE(review): the validation indices are simply the first `split` entries in file order
# (there is no shuffle before the cut) — confirm the CSV is not sorted by class.
dataset_size = len(train_dataset)
indices = list(range(dataset_size))
split = int(np.floor(0.15 * dataset_size))
train_indices, val_indices = indices[split:], indices[:split]
# Each sampler draws only from its own index subset, in random order.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

from torch.utils.data import DataLoader

# Both loaders wrap the same dataset object; the samplers keep the two subsets disjoint.
dataloader_train = DataLoader(train_dataset, batch_size=cfg.batch_size, 
                                  collate_fn=collate_func,sampler=train_sampler)
dataloader_val = DataLoader(train_dataset, batch_size=cfg.batch_size,
                                  collate_fn=collate_func,sampler=valid_sampler)

### Hourglass training

In [None]:
import torch
# from hourglass_transformer_pytorch import HourglassTransformer => we run modified hourglass above instead

# dim = 12*12 matches the last axis of the (1, 512, 12*12) tensor each sample is reshaped to
# in the training cells; depth is (pre_layers_depth, valley_depth, post_layers_depth).
model = HourglassTransformer(
    dim = 12*12,
    shorten_factor = 2,
    depth = (4, 2, 4),
    updown_sample_type = 'linear'
)


In [None]:
# TEST CELL
# One pass over the training loader, one optimizer step per sample. This cell also creates
# the `resnet_features`, `criterion` and `optimizer` objects used by the later training cells.

resnet_features = ResNetFeatures()
lr=0.01
model.train()
classes=['airplane','bird','car','cat','gazelle','boat','dog','horse','monkey','truck']
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.2, weight_decay=0.05, nesterov=True)#0.2
device = 'cpu'
for data, target in tqdm.tqdm(dataloader_train):
    # The ResNet is a fixed feature extractor (its parameters are not given to the
    # optimizer), so no graph is recorded for it. Each sample below then owns an
    # independent graph and backward() needs no retain_graph.
    with torch.no_grad():
        d = resnet_features(data).to(device)

    for x, label in zip(d, target): # one sample at a time
        y = torch.reshape(x, (1, 512, 12*12))

        output = model(y).float() # 1-D tensor of class scores
        output = torch.unsqueeze(output, dim=0)

        # NOTE(review): `label - 1` assumes the CSV labels are 1-based — confirm.
        label = torch.nn.functional.one_hot(label-1, num_classes=10)
        label = torch.unsqueeze(label, dim=0)

        l = criterion(output, label.float())

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

In [None]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard event files for this run are written to <logs>/vit_resnet_pretrained.
logs=os.path.join('/content/drive/MyDrive/BYOL-ViT-Hourglass/logs') #tboard
writer = SummaryWriter(logs + '/vit_resnet_pretrained')

def train(epoch, model, device, dataloader, optimizer, criterion,writer):
    """Train `model` for one epoch on `dataloader`, one optimizer step per sample.

    Relies on the notebook globals `resnet_features` (feature extractor), `RunningAverage`
    and `iter_cnt` (incremented once per sample).

    Args:
        epoch: epoch index, used as the x-value of the TensorBoard scalars.
        model: network being trained.
        device: device the features and labels are moved to.
        dataloader: yields `(images, labels)` batches.
        optimizer: optimizer updating `model`.
        criterion: loss called as `criterion(output, one_hot_label)`.
        writer: TensorBoard SummaryWriter.

    Returns:
        `(loss_record, acc_record)` — the RunningAverage objects themselves; call them to
        obtain the averaged numbers.
    """
    global iter_cnt

    loss_record = RunningAverage()
    acc_record = RunningAverage()
    correct=0
    total = 0
    model.train()

    # Iterate the loader that was passed in (not the notebook-global training loader).
    for data, target in tqdm.tqdm(dataloader):
        # The ResNet is a fixed feature extractor, so no graph is recorded for it; each
        # sample below then owns an independent graph and backward() needs no retain_graph.
        with torch.no_grad():
            d = resnet_features(data)

        for x, label in zip(d.to(device), target.to(device)): # one sample at a time

            x=torch.reshape(x,(1,512,12*12))
            output=model(x).float() # 1-D tensor of class scores
            output = torch.unsqueeze(output, dim=0)

            # NOTE(review): `label - 1` assumes the CSV labels are 1-based — confirm.
            label = torch.nn.functional.one_hot(label-1,num_classes=10)
            label = torch.unsqueeze(label, dim=0)

            l = criterion(output, label.float())

            # measure accuracy and record loss: compare class indices with class indices
            # (the label is one-hot, so its class index is its argmax)
            predicted = output.argmax(dim=1)
            correct += predicted.eq(label.argmax(dim=1)).sum().item()
            total+=label.size(0)
            acc = correct/total

            acc_record.update(100*acc)
            loss_record.update(l.item())

            iter_cnt+=1

            # compute gradient and do optimizer step
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

    LR=optimizer.param_groups[0]['lr']
    writer.add_scalar('train/Loss_epoch', loss_record(), epoch)
    writer.add_scalar('train/Acc_epoch', acc_record(), epoch)

    print('Train Epoch: {} LR: {:.4f} Avg Loss: {:.4f}; Avg Acc: {:.4f}'.format(epoch,LR, loss_record(), acc_record()))

    return loss_record,acc_record


In [None]:
def validate(epoch, model, device, dataloader, criterion,writer):
    """Evaluate `model` on `dataloader` without gradient tracking and log the metrics.

    Relies on the notebook globals `resnet_features` (feature extractor) and `RunningAverage`.

    Args:
        epoch: epoch index, used as the x-value of the TensorBoard scalars.
        model: network being evaluated (switched to eval mode).
        device: device the features and labels are moved to.
        dataloader: yields `(images, labels)` batches of the validation split.
        criterion: loss called as `criterion(output, one_hot_label)`.
        writer: TensorBoard SummaryWriter.

    Returns:
        `(average_loss, average_accuracy_in_percent)` as plain numbers.
    """
    loss_record = RunningAverage()
    acc_record = RunningAverage()
    model.eval()
    with torch.no_grad():
        # Iterate the loader that was passed in (not the notebook-global training loader).
        for data, target in tqdm.tqdm(dataloader, desc='Val'):
            d = resnet_features(data)
            for x, label in zip(d.to(device), target.to(device)): # one sample at a time
                x=torch.reshape(x,(1,512,12*12))

                output=model(x).float() # 1-D tensor of class scores
                output = torch.unsqueeze(output, dim=0)

                # NOTE(review): `label - 1` assumes the CSV labels are 1-based — confirm.
                label = torch.nn.functional.one_hot(label-1,num_classes=10)
                label = torch.unsqueeze(label, dim=0)

                l = criterion(output, label.float())

                # measure accuracy and record loss: each sample contributes 100 when the
                # highest-scoring class equals the labelled class, and 0 otherwise
                is_correct = output.argmax(dim=1).eq(label.argmax(dim=1)).sum().item()
                acc_record.update(100.0 * is_correct)
                loss_record.update(l.item())

    writer.add_scalar('validation/Loss_epoch', loss_record(), epoch)
    writer.add_scalar('validation/Acc_epoch', acc_record(), epoch)

    return loss_record(),acc_record()

In [None]:
class RunningAverage():
    """Accumulates values and reports their arithmetic mean.

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # -> 3.0
    ```
    """

    def __init__(self):
        # Sum of all values seen so far, and how many there were.
        self.total = 0
        self.steps = 0

    def update(self, val):
        """Add one observation."""
        self.steps = self.steps + 1
        self.total = self.total + val

    def __call__(self):
        """Return the mean of all observations (raises ZeroDivisionError when there are none)."""
        count = float(self.steps)
        return self.total / count

In [None]:
# main training cell

logs=os.path.join('/content/drive/MyDrive/BYOL-ViT-Hourglass/BYOL/experiments/hourglass_log') #tboard

# Notebook-level sample counter that train() increments through `global iter_cnt`.
iter_cnt=0

train_accs = []
test_accs = []
val=0
epochs = 600

for epoch in range(epochs):

    train_loss,train_acc = train(epoch, model, device, dataloader_train, optimizer, criterion,writer)
    val_loss,val_acc = validate(epoch, model, device, dataloader_val, criterion, writer)

    # Sum of the per-epoch validation accuracies; averaged after the loop.
    val=val+val_acc

    # train() hands back RunningAverage objects; store the number, not the object.
    train_accs.append(train_acc() if callable(train_acc) else train_acc)

    # Checkpoint every 100th epoch (0, 100, 200, ...).
    if epoch % 100 == 0:
        print(f"Saving model at epoch {epoch}")
        torch.save(model.state_dict(), f"/content/drive/MyDrive/BYOL-ViT-Hourglass/BYOL/experiments/Hourglass_res50_cct{epoch}.pth")



# Mean validation accuracy over all epochs.
vall=val/epochs
# Logs the validation accuracy of the last epoch.
writer.add_text('validation accuracy','acc {}'.format(val_acc))
writer.close()