<a href="https://colab.research.google.com/github/KyuhyoJeon/SimSiam/blob/main/SimSiam_PatrickHua.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://github.com/PatrickHua/SimSiam.git

In [None]:
import easydict
import os
from datetime import datetime

# Experiment configuration shared by every cell of the notebook.
args = easydict.EasyDict({
    'image_size': 224,
    'learning_rate': 0.3,  # origin = 0.2, spijkervet = 3e-4, patrickhua = 0.3
    'momentum': 0,
    'weight_decay': 1.5e-6,
    'batch_size': 256,
    'num_epochs': 100,
    'resnet_version': 'resnet18',
    'optim': 'lars',  # 'lars', 'adam', 'sgd'
    'checkpoint_epochs': 10,
    'dataset_dir': './datasets',
    'ckpt_dir': './ckpt',
    'log_dir': './log',
    'num_workers': 8,
    'nodes': 1,
    'gpus': 1,
    'nr': 0,
    'device': 'cuda',
    'eval': True,
    'dryrun': True,
    'debug': False,
})

# Every run logs into its own timestamped "in-progress" directory.
args.log_dir = os.path.join(args.log_dir, 'in-progress_' + datetime.now().strftime('%m%d%H%M%S_'))
# exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(args.log_dir, exist_ok=True)

if args.dryrun:
    # Short, low-resolution run used to smoke-test the whole pipeline.
    args.image_size = 32
    args.num_epochs = 10
    args.batch_size = 256
    args.num_workers = 0
    # NOTE(review): dryrun_subset_size is set here but not read by any code visible in this notebook.
    args.dryrun_subset_size = 100
    args.resnet_version = 'resnet18'

if args.debug:
    # Tiny run; applied after (and therefore overriding) the dry-run settings.
    args.image_size = 32
    args.num_epochs = 1
    args.batch_size = 2
    args.num_workers = 0
    args.debug_subset_size = 8
    args.resnet_version = 'resnet18'

# Checkpoints are grouped by backbone / optimizer / month-day-hour of the run.
tmp_dir = os.path.join(args.ckpt_dir, f"{args.resnet_version}", f"{args.optim}", f"{datetime.now().strftime('%m%d%H')}")
os.makedirs(tmp_dir, exist_ok=True)

In [None]:
from torchvision import transforms
from PIL import Image, ImageOps
from torchvision.transforms import GaussianBlur

# Per-channel [mean, std] used by transforms.Normalize.
imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]


class BYOL_transform: # Table 6
    """Callable that turns one image into two differently augmented views.

    Both pipelines share the same crop / flip / colour-jitter / grayscale
    stages and the same tensor conversion + normalisation. They differ in the
    blur stage: ``transform1`` always blurs, while ``transform2`` blurs with
    probability 0.1 and additionally solarizes with probability 0.2.
    """

    def __init__(self, image_size, normalize=imagenet_norm):
        # simclr paper gives the kernel size. Kernel size has to be an odd
        # positive number with torchvision.
        blur_kernel = image_size//20*2+1

        def geometric_and_colour():
            # Fresh transform instances per pipeline, as in two literal lists.
            return [
                transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.2,0.1)], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]

        def to_normalized_tensor():
            return [
                transforms.ToTensor(),
                transforms.Normalize(*normalize),
            ]

        view1_specific = [
            transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0)),
        ]
        view2_specific = [
            transforms.RandomApply([transforms.GaussianBlur(kernel_size=blur_kernel, sigma=(0.1, 2.0))], p=0.1),
            transforms.RandomApply([Solarization()], p=0.2),
        ]

        self.transform1 = transforms.Compose(
            geometric_and_colour() + view1_specific + to_normalized_tensor()
        )
        self.transform2 = transforms.Compose(
            geometric_and_colour() + view2_specific + to_normalized_tensor()
        )

    def __call__(self, x):
        """Return the pair ``(transform1(x), transform2(x))``."""
        first_view = self.transform1(x)
        second_view = self.transform2(x)
        return first_view, second_view


class Transform_single:
    """Single-view transform used for the memory bank, kNN test and linear evaluation.

    With ``train`` equal to ``True`` the image is randomly cropped and flipped;
    otherwise it is resized to ``int(image_size * 8/7)`` and centre-cropped
    back to ``image_size``. Both variants end with tensor conversion and
    normalisation.
    """

    def __init__(self, image_size, train, normalize=imagenet_norm):
        finalize = [
            transforms.ToTensor(),
            transforms.Normalize(*normalize),
        ]
        # Comparison against True is kept as-is so that only values equal to
        # True select the training pipeline.
        if train == True:
            spatial = [
                transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0), interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            spatial = [
                transforms.Resize(int(image_size*(8/7)), interpolation=Image.BICUBIC), # 224 -> 256
                transforms.CenterCrop(image_size),
            ]
        self.transform = transforms.Compose(spatial + finalize)

    def __call__(self, x):
        """Apply the configured pipeline to ``x``."""
        return self.transform(x)



class Solarization():
    """Callable wrapper around ``ImageOps.solarize`` with a fixed threshold."""

    def __init__(self, threshold=128):
        # Threshold forwarded to ImageOps.solarize on every call.
        self.threshold = threshold

    def __call__(self, image):
        """Return ``image`` solarized at ``self.threshold``."""
        solarized = ImageOps.solarize(image, self.threshold)
        return solarized

In [None]:
import torch
import torchvision

def _debug_subset(dataset, size):
    """Return the first ``size`` samples of ``dataset`` as a Subset that still exposes ``classes`` and ``targets``."""
    subset = torch.utils.data.Subset(dataset, range(0, size))
    subset.classes = dataset.classes
    # Keep the labels aligned with the samples that are actually in the subset.
    subset.targets = dataset.targets[:size]
    return subset


# Two augmented views per image, used for self-supervised training.
train_set = torchvision.datasets.CIFAR10(
    root=args.dataset_dir,
    train=True,
    transform=BYOL_transform(image_size=args.image_size),
    download=True
)

# Training images with the deterministic transform: feature bank for the kNN monitor.
memory_set = torchvision.datasets.CIFAR10(
    root=args.dataset_dir,
    train=True,
    transform=Transform_single(image_size=args.image_size, train=False),
    download=False
)

# Held-out images with the deterministic transform.
test_set = torchvision.datasets.CIFAR10(
    root=args.dataset_dir,
    train=False,
    transform=Transform_single(image_size=args.image_size, train=False),
    download=False
)

if args.debug:
    train_set = _debug_subset(train_set, args.debug_subset_size)
    memory_set = _debug_subset(memory_set, args.debug_subset_size)
    test_set = _debug_subset(test_set, args.debug_subset_size)

train_loader = torch.utils.data.DataLoader(
    train_set,
    shuffle=True,
    batch_size=args.batch_size,
    drop_last=True,
    pin_memory=True,
    num_workers=args.num_workers
)

# Evaluation loaders keep the final partial batch: dropping it would silently
# exclude samples from the feature bank and from the reported accuracy.
memory_loader = torch.utils.data.DataLoader(
    memory_set,
    shuffle=False,
    batch_size=args.batch_size,
    drop_last=False,
    pin_memory=True,
    num_workers=args.num_workers
)

test_loader = torch.utils.data.DataLoader(
    test_set,
    shuffle=False,
    batch_size=args.batch_size,
    drop_last=False,
    pin_memory=True,
    num_workers=args.num_workers
)

Files already downloaded and verified


In [None]:
import copy
import random 
from torch import nn 
import torch.nn.functional as F 
from torchvision import transforms 
from math import pi, cos 
from collections import OrderedDict
# Hyper-parameter table. In the code visible in this notebook only
# mlp_hidden_size, projection_size, base_target_ema and batchnorm_kwargs are
# read (by MLP and BYOL); the remaining entries are not referenced here.
HPS = dict(
    max_steps=int(1000. * 1281167 / 4096), # 1000 epochs * (1281167 samples / batch size 4096) = epochs * steps per epoch
    # i.e. total_epochs * len(dataloader); presumably the ImageNet-scale reference setting - TODO confirm
    mlp_hidden_size=4096,
    projection_size=256,
    base_target_ema=4e-3,
    optimizer_config=dict(
        optimizer_name='lars', 
        beta=0.9, 
        trust_coef=1e-3, 
        weight_decay=1.5e-6,
        exclude_bias_from_adaption=True),
    learning_rate_schedule=dict(
        base_learning_rate=0.2,
        warmup_steps=int(10.0 * 1281167 / 4096), # 10 epochs * steps per epoch = 10 epochs * len(dataloader)
        anneal_schedule='cosine'),
    batchnorm_kwargs=dict(
        decay_rate=0.9,
        eps=1e-5), 
    seed=1337,
)

# def loss_fn(x, y, version='simplified'):
    
#     if version == 'original':
#         y = y.detach()
#         x = F.normalize(x, dim=-1, p=2)
#         y = F.normalize(y, dim=-1, p=2)
#         return (2 - 2 * (x * y).sum(dim=-1)).mean()
#     elif version == 'simplified':
#         return (2 - 2 * F.cosine_similarity(x,y.detach(), dim=-1)).mean()
#     else:
#         raise NotImplementedError

def D(p, z, version='simplified'): # negative cosine similarity
    """Return the mean negative cosine similarity between ``p`` and ``z``.

    Gradients are never propagated into ``z`` (stop-gradient).

    Args:
        p: prediction tensor; similarity is taken over dim 1 ('original') or
            the last dim ('simplified').
        z: target tensor with the same shape as ``p``.
        version: 'original' normalises both inputs explicitly and sums their
            product; 'simplified' uses ``F.cosine_similarity`` directly.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: if ``version`` is neither 'original' nor 'simplified'.
    """
    z = z.detach() # stop gradient, shared by both formulations
    if version == 'original':
        p = F.normalize(p, dim=1) # l2-normalize
        z = F.normalize(z, dim=1) # l2-normalize
        return -(p*z).sum(dim=1).mean()
    if version == 'simplified':
        return - F.cosine_similarity(p, z, dim=-1).mean()
    # ValueError is a subclass of Exception, so existing handlers still catch it.
    raise ValueError(f"unknown version {version!r}; expected 'original' or 'simplified'")


class MLP(nn.Module):
    """Two-layer head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Hidden width, output width and batch-norm settings are read from ``HPS``;
    the batch-norm momentum is ``1 - decay_rate``.
    """

    def __init__(self, in_dim):
        super().__init__()

        hidden_dim = HPS['mlp_hidden_size']
        bn_cfg = HPS['batchnorm_kwargs']

        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, eps=bn_cfg['eps'], momentum=1-bn_cfg['decay_rate']),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Linear(hidden_dim, HPS['projection_size'])

    def forward(self, x):
        """Map ``x`` through the hidden layer and the output projection."""
        return self.layer2(self.layer1(x))

class BYOL(nn.Module):
    """Online/target network pair trained with the symmetric loss built from ``D``.

    The online encoder is ``backbone`` followed by a projector ``MLP``; the
    target encoder starts as a deep copy of it. A second ``MLP`` acts as the
    online predictor. ``forward`` returns the loss; the target encoder is only
    changed by ``update_moving_average``, which the training code has to call.
    """

    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.projector = MLP(backbone.output_dim)
        self.online_encoder = nn.Sequential(self.backbone, self.projector)

        # Snapshot of the online encoder at construction time.
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.online_predictor = MLP(HPS['projection_size'])

    def target_ema(self, k, K, base_ema=HPS['base_target_ema']):
        """Return the EMA coefficient tau for step ``k`` out of ``K``.

        ``base_ema`` is ``1 - tau_base``; with the HPS value 4e-3 this gives
        tau = 0.996 at ``k = 0``, rising along a cosine to 1 at ``k = K``.
        """
        shrink = base_ema * (cos(pi*k/K)+1)/2
        return 1 - shrink

    @torch.no_grad()
    def update_moving_average(self, global_step, max_steps):
        """Blend the online encoder's parameters into the target encoder's parameters."""
        tau = self.target_ema(global_step, max_steps)
        param_pairs = zip(self.online_encoder.parameters(), self.target_encoder.parameters())
        for src, dst in param_pairs:
            dst.data = tau * dst.data + (1 - tau) * src.data

    def forward(self, x1, x2):
        """Return ``{'loss': L}`` for the two views ``x1`` and ``x2``."""
        # Online branch: encode both views, then predict.
        online_z1 = self.online_encoder(x1)
        online_z2 = self.online_encoder(x2)

        pred1 = self.online_predictor(online_z1)
        pred2 = self.online_predictor(online_z2)

        # Target branch only provides regression targets; no graph is built.
        with torch.no_grad():
            target_z1 = self.target_encoder(x1)
            target_z2 = self.target_encoder(x2)

        # Each view's prediction is matched against the other view's target.
        L = D(pred1, target_z2) / 2 + D(pred2, target_z1) / 2
        return {'loss': L}

In [None]:
from torchvision import models

# Look the constructor up by name rather than eval()-ing a formatted string:
# an unknown name now fails with AttributeError and no arbitrary expression is executed.
resnet = getattr(models, args.resnet_version)()
# Remember the feature width, then replace the classification head with a
# pass-through so the backbone outputs features.
resnet.output_dim = resnet.fc.in_features
resnet.fc = torch.nn.Identity()

model = BYOL(resnet).to(args.device)
model = torch.nn.DataParallel(model)

In [None]:
from torch.optim.optimizer import Optimizer 
# comments from the lead author of byol
# 2. + 3. We follow the same implementation as the one used in SimCLR for LARS. This is indeed a bit 
# different from the one described in the LARS paper and the implementation you attached to your email. 
# In particular as in SimCLR we first modify the gradient to include the weight decay (with beta corresponding 
# to self.weight_decay in the SimCLR code) and then adapt the learning rate by dividing by the norm of this 
# sum, this is different from the LARS pseudo code where they divide by the sum of the norm (instead of the 
# norm of the sum as SimCLR and us are doing). This is done in the SimCLR code by first adding the weight 
# decay term to the gradient and then using this sum to perform the adaptation. We also use a term (usually 
# referred to as trust_coefficient but referred as eeta in SimCLR code) set to 1e-3 to multiply the updates 
# of linear layers.
# Note that the logic "if w_norm > 0 and g_norm > 0 else 1.0" is there to tackle numerical instabilities.
# In general we closely followed SimCLR implementation of LARS.
class LARS(Optimizer):
    """LARS-style optimizer with weight decay folded into the gradient before adaptation.

    Parameters are split into two groups by ``exclude_from_model``: 'base'
    (weights, which get weight decay and the layer-wise trust ratio) and
    'exclude' (biases and batch-norm parameters, which get neither).
    """

    def __init__(self,
                 named_modules,
                 lr,
                 momentum=0.9, # beta? YES
                 trust_coef=1e-3,
                 weight_decay=1.5e-6,
                exclude_bias_from_adaption=True):
        '''byol: As in SimCLR and official implementation of LARS, we exclude bias # and batchnorm weight from the Lars adaptation and weightdecay

        Args:
            named_modules: iterable of ``(name, module)`` pairs, e.g. ``model.named_modules()``.
            lr: global learning rate.
            momentum: coefficient applied to the stored velocity.
            trust_coef: multiplier of the ``||w|| / ||g||`` trust ratio.
            weight_decay: coefficient of the weight term added to the gradient.
            exclude_bias_from_adaption: if False, every parameter goes in the 'base' group.
        '''
        defaults = dict(momentum=momentum,
                lr=lr,
                weight_decay=weight_decay,
                 trust_coef=trust_coef)
        parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption)
        super(LARS, self).__init__(parameters, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups: # only 1 group in most cases
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            lr = group['lr']
            trust_coef = group['trust_coef']
            use_weight_decay = self._use_weight_decay(group)
            adapt_layerwise = self._do_layer_adaptation(group)

            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                velocity = state.get('velocity', 0)
                if use_weight_decay:
                    # The decay term is added to the gradient itself so that the
                    # trust ratio below uses the norm of the sum.
                    p.grad.data += weight_decay * p.data

                trust_ratio = 1.0
                if adapt_layerwise:
                    w_norm = torch.norm(p.data, p=2)
                    g_norm = torch.norm(p.grad.data, p=2)
                    # Fall back to 1.0 when either norm is zero.
                    trust_ratio = trust_coef * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
                scaled_lr = lr * trust_ratio # trust_ratio is the local_lr
                next_v = momentum * velocity + scaled_lr * p.grad.data
                # Persist the velocity; without this the momentum term is always
                # multiplied by the initial 0 and has no effect.
                state['velocity'] = next_v
                p.data = p.data - next_v

        return loss

    def _use_weight_decay(self, group):
        """Weight decay applies to every group except the one named 'exclude'."""
        return group['name'] != 'exclude'

    def _do_layer_adaptation(self, group):
        """The trust ratio applies to every group except the one named 'exclude'."""
        return group['name'] != 'exclude'

    def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True):
        """Build the optimizer's parameter groups from ``(name, module)`` pairs.

        Only parameters owned directly by a module are considered, so each
        parameter is collected from the module that defines it. Batch-norm
        parameters and parameters named 'bias' go to 'exclude'; parameters
        named 'weight' go to 'base'; any other directly owned parameter is
        not optimised.
        """
        batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        base = []
        exclude = []
        for _, module in named_modules:
            own_params = module.named_parameters(recurse=False)
            if isinstance(module, batchnorm_types):
                exclude.extend(param for _, param in own_params)
                continue
            for param_name, param in own_params:
                if param_name == 'bias':
                    exclude.append(param)
                elif param_name == 'weight':
                    base.append(param)

        if exclude_bias_from_adaption == True:
            return [
                {'name': 'base', 'params': base},
                {'name': 'exclude', 'params': exclude},
            ]
        return [{'name': 'base', 'params': base + exclude}]

In [None]:
# Pick the optimizer named by args.optim. 'lars' and any unrecognised name
# both end up with LARS; Adam and LARS use the learning rate scaled by
# batch_size / 256, plain SGD uses the unscaled learning rate.
if args.optim == 'adam':
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate*args.batch_size/256)
elif args.optim == 'sgd':
  optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
else:
  optimizer = LARS(
      model.named_modules(),
      lr=args.learning_rate*args.batch_size/256,
      momentum=args.momentum,
      weight_decay=args.weight_decay
  )

In [None]:
import numpy as np

class LR_Scheduler(object):
    """Per-iteration schedule: linear warm-up followed by cosine annealing.

    The whole schedule is precomputed as an array with one entry per
    iteration; ``step`` writes the current entry into every parameter group.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an 'lr' entry).
        warmup_epochs: number of epochs of linear warm-up.
        warmup_lr: learning rate at the first warm-up iteration.
        num_epochs: total number of epochs covered by the schedule.
        base_lr: learning rate reached at the end of warm-up.
        final_lr: value the cosine decay heads towards.
        iter_per_epoch: number of iterations per epoch.
        constant_predictor_lr: if True, groups whose 'name' is 'predictor'
            are held at ``base_lr`` instead of following the schedule.
    """

    def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
        self.base_lr = base_lr
        self.constant_predictor_lr = constant_predictor_lr
        warmup_iter = iter_per_epoch * warmup_epochs
        warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
        decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
        cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))

        self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
        self.optimizer = optimizer
        self.iter = 0
        self.current_lr = 0

    def step(self):
        """Advance one iteration and return the scheduled learning rate.

        Raises:
            IndexError: when called more times than the schedule has entries.
        """
        # Read the scheduled value up front so it is defined even if every
        # group is a constant-lr predictor group.
        lr = self.lr_schedule[self.iter]
        for param_group in self.optimizer.param_groups:
            # .get(): groups created by stock optimizers carry no 'name' key.
            is_predictor = param_group.get('name') == 'predictor'
            if self.constant_predictor_lr and is_predictor:
                param_group['lr'] = self.base_lr
            else:
                param_group['lr'] = lr

        self.iter += 1
        self.current_lr = lr
        return lr

    def get_lr(self):
        """Return the learning rate produced by the most recent ``step`` (0 before the first)."""
        return self.current_lr

In [None]:
# 10 warm-up epochs from 0 up to the batch-size-scaled learning rate, then
# cosine decay towards 0 over the remaining epochs.
lr_scheduler = LR_Scheduler(
    optimizer,
    warmup_epochs=10,
    warmup_lr=0*args.batch_size/256,
    num_epochs=args.num_epochs,
    base_lr=args.learning_rate*args.batch_size/256,
    final_lr=0*args.batch_size/256,
    iter_per_epoch=len(train_loader),
    constant_predictor_lr=True, # see the end of section 4.2 predictor
)

In [None]:
import matplotlib
matplotlib.use('Agg') #https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter

class Plotter(object):
    """Accumulate scalar series by key and render them as stacked line plots."""

    def __init__(self):
        # key -> list of recorded values, in insertion order of the keys.
        self.logger = OrderedDict()

    def update(self, ordered_dict):
        """Append each value of ``ordered_dict`` to the series stored under its key.

        Tensor values are converted with ``.item()``; the converted number is
        written back into ``ordered_dict`` and is also what gets stored.
        """
        for key, value in ordered_dict.items():
            if isinstance(value, Tensor):
                # Store the plain number, not the tensor, so the history holds
                # no reference to the tensor.
                value = value.item()
                ordered_dict[key] = value
            self.logger.setdefault(key, []).append(value)

    def save(self, file, **kwargs):
        """Write one subplot per recorded key to ``file``; does nothing when no key is recorded.

        Extra keyword arguments are forwarded to ``savefig``.
        """
        if not self.logger:
            return
        n_series = len(self.logger)
        # squeeze=False always yields a 2-D array of axes, so a single series
        # is handled the same way as several.
        fig, axes = plt.subplots(nrows=n_series, ncols=1, figsize=(8,2*n_series), squeeze=False)
        for ax, (key, value) in zip(axes[:, 0], self.logger.items()):
            ax.plot(value)
            ax.set_title(key)
        # Lay out after the titles exist so they are taken into account.
        fig.tight_layout()

        fig.savefig(file, **kwargs)
        plt.close(fig)

class Logger(object):
    """Send scalar values to a tensorboard writer and/or a ``Plotter``.

    Args:
        log_dir: directory handed to ``SummaryWriter`` and used for 'plotter.svg'.
        tensorboard: create a ``SummaryWriter`` when truthy.
        matplotlib: create a ``Plotter`` when truthy.
    """

    def __init__(self, log_dir, tensorboard=True, matplotlib=True):
        self.reset(log_dir, tensorboard, matplotlib)

    def reset(self, log_dir=None, tensorboard=True, matplotlib=True):
        """(Re)create the back-ends and clear the per-key step counters.

        When ``log_dir`` is None the previously stored directory is kept.
        """
        if log_dir is not None:
            self.log_dir = log_dir

        if tensorboard:
            self.writer = SummaryWriter(log_dir=self.log_dir)
        else:
            self.writer = None

        if matplotlib:
            self.plotter = Plotter()
        else:
            self.plotter = None

        # key -> number of values logged so far (used as the tensorboard step).
        self.counter = OrderedDict()

    def update_scalers(self, ordered_dict):
        """Log every entry of ``ordered_dict``.

        Tensor values are replaced in ``ordered_dict`` by their ``.item()``.
        When a plotter is present the figure is re-saved on every call.
        """
        for key, value in ordered_dict.items():
            if isinstance(value, Tensor):
                ordered_dict[key] = value.item()

            step = self.counter.get(key, 0) + 1
            self.counter[key] = step

            if self.writer:
                self.writer.add_scalar(key, value, step)

        if self.plotter:
            self.plotter.update(ordered_dict)
            self.plotter.save(os.path.join(self.log_dir, 'plotter.svg'))

In [None]:
# One logger for the whole run, writing both tensorboard events and the
# matplotlib summary into the run's log directory.
logger = Logger(
    log_dir=args.log_dir,
    tensorboard=True,
    matplotlib=True,
)

In [None]:
from tqdm import tqdm

accuracy = 0  # placeholder: no evaluation is run inside the training loop
# Total number of optimisation steps; sets the horizon of the target-EMA schedule.
max_steps = args.num_epochs * len(train_loader)
global_step = 0

# Start training
global_progress = tqdm(range(0, args.num_epochs), desc=f'Training')
for epoch in global_progress:
    model.train()

    local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{args.num_epochs}', disable=False)
    for (images1, images2), _ in local_progress:

        model.zero_grad()
        data_dict = model(images1.to(args.device, non_blocking=True), images2.to(args.device, non_blocking=True))
        loss = data_dict['loss'].mean() # DataParallel may return one loss per replica
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Move the target encoder towards the online encoder. Without this
        # call the target stays frozen at its initial copy for the whole run.
        model.module.update_moving_average(global_step, max_steps)
        global_step += 1

        # Log plain numbers: the reduced loss rather than the raw (possibly
        # multi-element, graph-attached) tensor.
        data_dict.update({'loss': loss.item(), 'lr': lr_scheduler.get_lr()})

        local_progress.set_postfix(data_dict)
        logger.update_scalers(data_dict)

    epoch_dict = {"epoch":epoch, "accuracy":accuracy}
    global_progress.set_postfix(epoch_dict)
    logger.update_scalers(epoch_dict)

# Save checkpoint
model_path = os.path.join(tmp_dir, f"byol_{datetime.now().strftime('%m%d%H%M%S')}.pth") # datetime.now().strftime('%Y%m%d_%H%M%S')
torch.save({
    'epoch': epoch+1,
    'state_dict':model.module.state_dict()
}, model_path)
print(f"Model saved to {model_path}")
# Record where the checkpoint went next to the run's logs.
with open(os.path.join(args.log_dir, f"checkpoint_path.txt"), 'w') as f:
    f.write(f'{model_path}')

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m







Epoch 8/10:  85%|████████▍ | 165/195 [03:31<00:38,  1.28s/it, loss=tensor(-0.5662, device='cuda:0', grad_fn=<AddBackward0>), lr=0.265][A[A[A[A[A[A[A[A[A[A[A










Epoch 8/10:  85%|████████▍ | 165/195 [03:32<00:38,  1.28s/it, loss=tensor(-0.5687, device='cuda:0', grad_fn=<AddBackward0>), lr=0.266][A[A[A[A[A[A[A[A[A[A[A










Epoch 8/10:  85%|████████▌ | 166/195 [03:32<00:36,  1.27s/it, loss=tensor(-0.5687, device='cuda:0', grad_fn=<AddBackward0>), lr=0.266][A[A[A[A[A[A[A[A[A[A[A










Epoch 8/10:  85%|████████▌ | 166/195 [03:33<00:36,  1.27s/it, loss=tensor(-0.5676, device='cuda:0', grad_fn=<AddBackward0>), lr=0.266][A[A[A[A[A[A[A[A[A[A[A










Epoch 8/10:  86%|████████▌ | 167/195 [03:34<00:35,  1.27s/it, loss=tensor(-0.5676, device='cuda:0', grad_fn=<AddBackward0>), lr=0.266][A[A[A[A[A[A[A[A[A[A[A










Epoch 8/10:  86%|████████▌ | 167/195 [03:35<00:35,  1

Model saved to ./ckpt/resnet18/lars/021408/byol_0214091941.pth


In [None]:
def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
    """Evaluate ``net`` with a weighted kNN classifier and return top-1 accuracy in percent.

    Features of ``memory_data_loader`` form the bank; every sample of
    ``test_data_loader`` is classified by ``knn_predict`` against that bank.

    Args:
        net: feature extractor; put into eval mode by this function.
        memory_data_loader: loader yielding ``(data, target)``; its dataset must expose ``classes``.
        test_data_loader: loader yielding ``(data, target)`` to classify.
        epoch: accepted for interface compatibility; not used.
        k: number of neighbours (capped at the size of the feature bank).
        t: temperature applied to the similarities.
        hide_progress: disable the tqdm bars.

    Returns:
        Top-1 accuracy in percent, or 0.0 if the test loader yields no samples.
    """
    net.eval()
    # Run on whatever device the network lives on instead of assuming CUDA.
    device = next(net.parameters()).device
    classes = len(memory_data_loader.dataset.classes)
    total_top1, total_num = 0.0, 0
    feature_bank, label_bank = [], []
    with torch.no_grad():
        # generate feature bank
        for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
            feature = net(data.to(device, non_blocking=True))
            feature = F.normalize(feature, dim=1)
            feature_bank.append(feature)
            # Take labels from the same batches as the features so both stay
            # aligned even if the loader drops or subsets samples.
            label_bank.append(target)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.cat(label_bank, dim=0).to(feature_bank.device)
        k = min(k, feature_bank.size(1))
        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
        for data, target in test_bar:
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            feature = net(data)
            feature = F.normalize(feature, dim=1)

            pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)

            total_num += data.size(0)
            total_top1 += (pred_labels[:, 0] == target).float().sum().item()
            test_bar.set_postfix({'Accuracy':total_top1 / total_num * 100})
    if total_num == 0:
        return 0.0
    return total_top1 / total_num * 100

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes for each query by similarity-weighted votes of its nearest neighbours.

    Args:
        feature: query features, [B, D].
        feature_bank: bank features laid out as [D, N].
        feature_labels: integer labels of the bank entries, [N].
        classes: number of classes.
        knn_k: number of neighbours taken per query.
        knn_t: temperature; votes are weighted by ``exp(similarity / knn_t)``.

    Returns:
        [B, classes] tensor of class indices sorted from highest to lowest score.
    """
    batch = feature.size(0)

    # similarity between every query and every bank entry ---> [B, N]
    similarities = torch.mm(feature, feature_bank)
    # the knn_k most similar bank entries per query ---> [B, K] each
    top_sim, top_idx = similarities.topk(k=knn_k, dim=-1)
    # labels of those neighbours ---> [B, K]
    neighbour_labels = torch.gather(feature_labels.expand(batch, -1), dim=-1, index=top_idx)
    vote_weight = (top_sim / knn_t).exp()

    # one-hot encoding of every neighbour label ---> [B*K, C]
    one_hot = torch.zeros(batch * knn_k, classes, device=neighbour_labels.device).scatter(
        dim=-1, index=neighbour_labels.view(-1, 1), value=1.0
    )
    # sum the weighted votes over the neighbours ---> [B, C]
    weighted_votes = one_hot.view(batch, -1, classes) * vote_weight.unsqueeze(dim=-1)
    class_scores = torch.sum(weighted_votes, dim=1)

    return class_scores.argsort(dim=-1, descending=True)

In [None]:
# knn_monitor's fourth parameter is `epoch`, not a device, so pass the number
# of completed epochs by keyword instead of args.device positionally.
accuracy = knn_monitor(
    model.module.backbone,
    memory_loader,
    test_loader,
    epoch=args.num_epochs,
    k=min(200, len(memory_loader.dataset)),
    hide_progress=False,
)
print('Accuracy:', accuracy)











Feature extracting:   0%|          | 0/195 [00:00<?, ?it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   1%|          | 1/195 [00:00<00:19,  9.77it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   1%|          | 2/195 [00:00<00:19,  9.76it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   2%|▏         | 4/195 [00:00<00:18, 10.31it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   3%|▎         | 6/195 [00:00<00:17, 11.00it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   4%|▍         | 8/195 [00:00<00:16, 11.38it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   5%|▌         | 10/195 [00:00<00:15, 11.87it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   6%|▌         | 12/195 [00:00<00:14, 12.29it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   7%|▋         | 14/195 [00:01<00:14, 12.62it/s][A[A[A[A[A[A[A[A[A[A









Feature extracting:   8%|▊      

In [None]:
class AverageMeter():
    """Track the latest value and the running weighted average of a quantity.

    ``reset`` pushes the current average onto ``log`` before clearing the
    running statistics.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.log = []
        self.val = self.avg = self.sum = self.count = 0

    def reset(self):
        """Archive the current average in ``log`` and zero the statistics."""
        self.log.append(self.avg)
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Format as '<name> <val> (<avg>)' using ``fmt`` for both numbers."""
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))

def linear_eval(args, eval_from):
    """Train a linear classifier on frozen backbone features and print its test accuracy.

    The backbone weights are loaded from the 'backbone.'-prefixed entries of
    the checkpoint's 'state_dict'. Only the linear classifier is optimised
    (SGD, 30 epochs, cosine schedule); the backbone stays in eval mode and is
    run without gradients.

    Args:
        args: configuration exposing dataset_dir, image_size, batch_size,
            num_workers, resnet_version and device.
        eval_from: path of the checkpoint to evaluate; must not be None.
    """
    assert eval_from is not None

    num_eval_epochs = 30
    base_lr = 30*args.batch_size/256

    train_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=True,
            transform=Transform_single(args.image_size, train=True)
        ),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=args.num_workers
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            args.dataset_dir,
            train=False,
            transform=Transform_single(args.image_size, train=False)
        ),
        batch_size=args.batch_size,
        shuffle=False,
        # Keep the final partial batch so every test image is scored.
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers
    )

    # Look the constructor up by name rather than eval()-ing a formatted string.
    model = getattr(models, args.resnet_version)()
    model.output_dim = model.fc.in_features
    model.fc = torch.nn.Identity()
    classifier = nn.Linear(in_features=model.output_dim, out_features=10, bias=True).to(args.device)

    save_dict = torch.load(eval_from, map_location=args.device)
    prefix = 'backbone.'
    backbone_state = {
        key[len(prefix):]: value
        for key, value in save_dict['state_dict'].items()
        if key.startswith(prefix)
    }
    model.load_state_dict(backbone_state, strict=True)

    model = model.to(args.device)
    model = torch.nn.DataParallel(model)
    classifier = torch.nn.DataParallel(classifier)

    # Optimise the classifier. The backbone is evaluated under no_grad and so
    # never receives gradients; an optimizer built over its parameters would
    # leave the classifier at its random initialisation.
    optimizer = torch.optim.SGD(classifier.parameters(), lr=base_lr, momentum=0.9, weight_decay=0)

    # No warm-up; cosine decay from base_lr to 0 over the evaluation epochs.
    lr_scheduler = LR_Scheduler(
        optimizer,
        0, 0*args.batch_size/256,
        num_eval_epochs, base_lr, 0*args.batch_size/256,
        len(train_loader),
    )

    loss_meter = AverageMeter(name='Loss')
    acc_meter = AverageMeter(name='Accuracy')

    # Start training
    global_progress = tqdm(range(0, num_eval_epochs), desc=f'Evaluating')
    for epoch in global_progress:
        loss_meter.reset()
        model.eval()
        classifier.train()
        local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{num_eval_epochs}', disable=True)

        for images, labels in local_progress:
            classifier.zero_grad()
            with torch.no_grad():
                feature = model(images.to(args.device))

            preds = classifier(feature)
            loss = F.cross_entropy(preds, labels.to(args.device))

            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())
            lr = lr_scheduler.step()
            local_progress.set_postfix({'lr':lr, "loss":loss_meter.val, 'loss_avg':loss_meter.avg})

    classifier.eval()
    acc_meter.reset()
    with torch.no_grad():
        for images, labels in test_loader:
            feature = model(images.to(args.device))
            preds = classifier(feature).argmax(dim=1)
            batch_size = preds.shape[0]
            correct = (preds == labels.to(args.device)).sum().item()
            # Weight by batch size so a smaller final batch does not skew the average.
            acc_meter.update(correct/batch_size, batch_size)
    print(f'Accuracy = {acc_meter.avg*100:.2f}')

In [None]:
# Run the linear evaluation on the checkpoint saved by the training cell,
# unless args.eval is exactly False. Requires model_path to be defined, i.e.
# the training cell must have run in this session.
if args.eval is not False:
    linear_eval(args, model_path)

NameError: ignored