<a href="https://colab.research.google.com/github/Hoseung/ATM/blob/main/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [11]:
def get_file_id_by_model(folder_name):
  file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
             'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
             'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}
  return file_id.get(folder_name, "Model not found.")

In [12]:
folder_name = 'resnet18_100-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet18_100-epochs_stl10 14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF


In [3]:
def get_stl10_data_loaders(download, n_channel=3, batch_size=256):
    _transform = [transforms.ToTensor()]
    if n_channel == 1:
        _transform = _transform + [transforms.Lambda(lambda x: x.mean(dim=0, keepdim=True))]
    train_dataset = datasets.STL10('./data', split='train', download=download,
                                  transform=transforms.Compose(_transform))

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=10, drop_last=False, shuffle=True)

    test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.Compose(_transform))

    test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=False)
    return train_loader, test_loader

def get_cifar10_data_loaders(download, n_channel=3, shuffle=False, batch_size=256):
    _transform = [transforms.ToTensor()]
    if n_channel == 1:
        _transform = _transform + [transforms.Lambda(lambda x: x.mean(dim=0, keepdim=True))]
    train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.Compose(_transform))

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=10, drop_last=False, shuffle=True)

    test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.Compose(_transform))

    test_loader = DataLoader(test_dataset, batch_size=2*batch_size, # why *2?
                            num_workers=10, drop_last=False, shuffle=False)
    return train_loader, test_loader

In [4]:
dataset_name = ['galaxy', 'cifar10', 'stl10'][0]
if dataset_name == 'galaxy':
    tmo_params = {'b': 6.0,  'c': 3.96, 'dl': 9.22, 'dh': 2.45}
    
    import pickle
    
    ddir = "../../tonemap/bf_data/Nair_and_Abraham_2010/"

    fn = ddir + "all_gals.pickle"
    all_gals = pickle.load(open(fn, "rb"))
    all_gals = all_gals[1:] # Why the first galaxy image is NaN?
    good_gids = np.array([gal['img_name'] for gal in all_gals])

    from astrobf.utils.misc import load_Nair
    cat_data = load_Nair(ddir + "catalog/table2.dat")
    # pd dataframe

    cat = cat_data[cat_data['ID'].isin(good_gids)]
    labels = np.digitize(cat['TT'], np.sort(cat['TT'].unique()), right=True)
    
    n_classes = cat['TT'].nunique()
else:
    n_classes = 10
    
arch = 'resnet50'
batch_size = 32
n_channels = 1

In [5]:
import atm
import atm.simclr as simclr
import atm.simclr.resnet as models

if arch == 'resnet18':
    #model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
    model = models.resnet18(pretrained=False, num_classes=n_classes, num_channels=n_channels).to(device)
elif arch == 'resnet50':
    #model = torchvision.models.resnet50(pretrained=False, num_classes=n_classes).to(device)
    model = models.resnet50(pretrained=False, num_classes=n_classes, num_channels=n_channels).to(device)

In [11]:
bare = False

if bare:
    # Bare R18_ch1
    fn_pth = '/home/hoseung/Dropbox/temp/runs/20210930-025624_cifar10_resnet18_1_256/Resnet_ch1_cifar10_bn256_200_199.pth'
else:
    # SimCLR_R50_ch1 
    fn_pth = '/home/hoseung/Dropbox/temp/runs/20210929-192413_cifar10_resnet50_1_256/checkpoint_0200.pth.tar'

In [6]:
#fn_pth = 'checkpoint_0100.pth.tar'
fn_pth = '/home/hoseung/Dropbox/temp/Sep29_04-42-49_lambda/checkpoint_0300.pth.tar' # Resnet50
checkpoint = torch.load(fn_pth, map_location=device)
state_dict = checkpoint['state_dict']

In [7]:
target_word = ['module.backbone', 'backbone'][0]
for k in list(state_dict.keys()):
    if k.startswith(target_word+'.'):
        if k.startswith(target_word) and not k.startswith(target_word+'.fc'):
          # remove prefix
          state_dict[k[len(target_word+'.'):]] = state_dict[k]
    del state_dict[k]

In [8]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [9]:
from atm.loader import TonemapImageDataset
from functools import partial
from astrobf.tmo import Mantiuk_Seidel

def get_simclr_pipeline_transform(size, s=1, n_channels=3):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    if n_channels == 3:
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * size)),
                                          transforms.ToTensor()])
    elif n_channels == 1:
        data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                          transforms.RandomHorizontalFlip(),
#                                          transforms.Lambda(lambda x: x.mean(dim=0, keepdim=True)),
                                          transforms.ToTensor()])

    return data_transforms

if dataset_name == 'cifar10':
    train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif dataset_name == 'stl10':
    train_loader, test_loader = get_stl10_data_loaders(download=True)
elif dataset_name == "galaxy":
    train_dataset = TonemapImageDataset(all_gals, partial(Mantiuk_Seidel, **tmo_params),
                                        labels=labels,
                                        train=True, 
                                        transform=transforms.Compose([transforms.RandomResizedCrop(size=128),
                                                                      transforms.ToTensor()]),
                                        #get_simclr_pipeline_transform(128, n_channels=n_channels)
                                        )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, drop_last=True)
    
    test_dataset = TonemapImageDataset(all_gals, partial(Mantiuk_Seidel, **tmo_params),
                                        labels=cat['TT'].to_numpy(),
                                        train=False, 
                                        transform=transforms.Compose([transforms.RandomResizedCrop(size=128),
                                                                      transforms.ToTensor()]),
                                       #get_simclr_pipeline_transform(128, n_channels=n_channels)
                                        )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=2*batch_size, shuffle=True, # why 2*??
        pin_memory=True, drop_last=True)
    
print("Dataset:", dataset_name)

Dataset: galaxy


In [10]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [12]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [13]:
import time
from torch.utils.tensorboard import SummaryWriter
timestr = time.strftime("%Y%m%d-%H%M%S")
log_dir = timestr + f"_{dataset_name}_{arch}_{n_channels}_transfer"
writer = SummaryWriter(log_dir=log_dir)
print(writer.log_dir)

epochs = 100
for epoch in range(epochs):
    top1_train_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(train_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)
        loss = criterion(logits, y_batch)

        top1 = accuracy(logits, y_batch, topk=(1,))
        top1_train_accuracy += top1[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    top1_train_accuracy /= (counter + 1)
    top1_accuracy = 0
    top5_accuracy = 0
    for counter, (x_batch, y_batch) in enumerate(test_loader):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)

        top1, top5 = accuracy(logits, y_batch, topk=(1,5))
        top1_accuracy += top1[0]
        top5_accuracy += top5[0]

    top1_accuracy /= (counter + 1)
    top5_accuracy /= (counter + 1)

    # write everystep
    writer.add_scalar('loss', loss, global_step=epoch)
    writer.add_scalar('acc/top1', top1_accuracy, global_step=epoch)
    writer.add_scalar('acc/top5', top5_accuracy, global_step=epoch)
    #writer.add_scalar('learning_rate', self.scheduler.get_lr()[0], global_step=n_iter)

    print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}"
            f"\tTop1 Test accuracy: {top1_accuracy.item()}"
            f"\tTop5 test acc: {top5_accuracy.item()}")

20210930-122043_galaxy_resnet50_1_transfer


  lp = np.log10(lum) # L prime


Epoch 0	Top1 Train accuracy 27.086828231811523	Top1 Test accuracy: 4.414848804473877	Top5 test acc: 30.494966506958008


KeyboardInterrupt: 

Make sure the output dimension of the model meets the number of possible labels, which is NOT 10 for Nair dataset. Mismatch will cause the error as follows:  


`loss.backward()  `  
`...  `  
`RuntimeError: cuda runtime error (710) : device-side assert triggered at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:235`

In [None]:
    
# '20210930-044403_stl10_resnet18_3_transfer'

In [20]:
# ResNet이 내꺼

Epoch 0	Top1 Train accuracy 35.14229965209961	Top1 Test accuracy: 42.522403717041016	Top5 test acc: 88.72932434082031
Epoch 1	Top1 Train accuracy 43.14293670654297	Top1 Test accuracy: 44.34397888183594	Top5 test acc: 89.89545440673828
Epoch 2	Top1 Train accuracy 45.027503967285156	Top1 Test accuracy: 45.86511993408203	Top5 test acc: 90.63419342041016
Epoch 3	Top1 Train accuracy 46.600364685058594	Top1 Test accuracy: 46.72966384887695	Top5 test acc: 91.29595184326172
Epoch 4	Top1 Train accuracy 47.303890228271484	Top1 Test accuracy: 47.55974197387695	Top5 test acc: 91.66475677490234
Epoch 5	Top1 Train accuracy 48.06959533691406	Top1 Test accuracy: 48.08249282836914	Top5 test acc: 91.78308868408203
Epoch 6	Top1 Train accuracy 48.61606979370117	Top1 Test accuracy: 48.09455490112305	Top5 test acc: 91.86235809326172
Epoch 7	Top1 Train accuracy 49.043365478515625	Top1 Test accuracy: 48.32892990112305	Top5 test acc: 92.18462371826172
Epoch 8	Top1 Train accuracy 49.37818908691406	Top1 Test acc

In [19]:
loss

tensor(0.8032, device='cuda:0', grad_fn=<NllLossBackward>)

In [20]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [21]:
writer.close()