In [1]:
import sys

# Prepend the parent directory so project packages (presumably `nets` and
# `utils`, imported below) can be found.
sys.path[:0] = ["../"]

In [2]:
# Pin the run to one physical GPU, numbered in PCI bus order. These variables
# only take effect if set before CUDA is initialised in this process, so this
# cell must run before torch touches the GPU.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=1


In [3]:
from nets.clr_nets import SupConResNet, SupCEResNet
import numpy as np
import torch 
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import torch.nn as nn
import numpy as np


class LinearProtoMap:
    """Nearest-prototype (or plain output-head) evaluation of a trained network.

    Class prototypes are the per-class mean of ``net.rep``, the attribute the
    wrapped network populates during ``forward``. They are collected over one
    of ``dataloaders``, and every sample is then scored against every prototype.

    Args:
        dataloaders: indexable collection of iterables yielding
            ``(data, labels)`` batches. Index 0 is used unless another index is
            passed to :meth:`data_prototypes` / :meth:`get_prob`.
        net: module called as ``net(dat, task)`` that returns logits and stores
            its representation in ``net.rep``.
        n_cls: number of classes; labels must lie in ``[0, n_cls)``.
    """

    def __init__(self, dataloaders, net, n_cls=10):
        self.criterion = nn.CrossEntropyLoss()
        self.dataloaders = dataloaders
        self.net = net
        self.gpu = torch.cuda.is_available()
        self.n_cls = n_cls
        if self.gpu:
            self.net.cuda()

    def data_prototypes(self, loader_idx=0):
        """Return an ``(n_cls, hid_dim)`` tensor of per-class mean representations.

        Args:
            loader_idx: which entry of ``self.dataloaders`` to average over.

        A class with no samples in that loader gets an all-zero prototype,
        instead of the NaN row a 0/0 division would produce.
        """
        # Representation width: explicit ``encoder_dim`` if the net exposes it,
        # otherwise the input width of its ``fc`` layer.
        if hasattr(self.net, "encoder_dim"):
            hid_dim = self.net.encoder_dim
        else:
            hid_dim = self.net.fc.weight.shape[1]

        device = "cuda" if self.gpu else "cpu"
        sums = torch.zeros(self.n_cls, hid_dim, device=device)
        counts = torch.zeros(self.n_cls, device=device)

        # Prototypes must come from deterministic features: in train mode any
        # BatchNorm/Dropout layers would use batch statistics / random masks.
        self.net.eval()
        with torch.inference_mode():
            for dat, labels in self.dataloaders[loader_idx]:
                dat, task, labels = self.create_batch(dat, labels)
                self.net(dat, task)
                sums.index_add_(0, labels, self.net.rep)
                counts += torch.bincount(labels, minlength=self.n_cls)

        return sums / counts.clamp(min=1).reshape(-1, 1)

    def euclid_dist(self, proto, rep, euclid=False):
        """Score each representation against each prototype.

        Args:
            proto: ``(k, d)`` prototype matrix.
            rep: ``(n, d)`` representations.
            euclid: if True, return negative squared Euclidean distance.
                Otherwise (the default, despite the method name) return the
                dot product ``rep @ proto.T``.

        Returns:
            ``(n, k)`` logits; larger means more similar.
        """
        if euclid:
            diff = rep.unsqueeze(1) - proto.unsqueeze(0)  # (n, k, d) via broadcasting
            return -(diff ** 2).sum(dim=2)
        return rep @ proto.T

    def create_batch(self, dat, labels):
        """Cast labels to int64 and move the batch to the GPU when available.

        Returns ``(dat, task, labels)``. ``task`` is always ``None`` here and is
        passed through to ``net(dat, task)``.
        """
        task = None
        labels = labels.long()
        if self.gpu:
            dat = dat.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        return dat, task, labels

    def get_prob(self, proto=True, euclid=False, proto_idx=0, eval_idx=0):
        """Evaluate the network and return per-sample class probabilities.

        Args:
            proto: if True, classify by similarity to the class prototypes from
                :meth:`data_prototypes`; otherwise use the network's own output.
            euclid: forwarded to :meth:`euclid_dist` when ``proto`` is True.
            proto_idx: index of the loader used to build the prototypes.
            eval_idx: index of the loader that is scored. Both default to 0,
                so prototypes are built on and evaluated against the same loader.

        Returns:
            ``(probs, metrics)``:
            - ``probs`` is an ``(N, C)`` numpy array of softmax probabilities,
              in loader order.
            - ``metrics`` is ``np.array([accuracy, mean cross-entropy loss])``.

        Raises:
            ValueError: if the evaluated loader yields no samples.
        """
        self.net.eval()
        prototypes = self.data_prototypes(proto_idx) if proto else None

        all_probs = []
        loss_sum, n_correct, count = 0.0, 0, 0
        with torch.inference_mode():
            for dat, labels in self.dataloaders[eval_idx]:
                dat, task, labels = self.create_batch(dat, labels)
                batch_size = labels.size(0)

                if proto:
                    self.net(dat, task)
                    out = self.euclid_dist(prototypes, self.net.rep, euclid=euclid)
                else:
                    out = self.net(dat, task)

                # The criterion returns a batch mean; re-weight it so the final
                # average is per-sample even when the last batch is smaller.
                loss_sum += self.criterion(out, labels).item() * batch_size
                n_correct += (out.argmax(dim=1) == labels).sum().item()
                count += batch_size
                all_probs.append(torch.softmax(out.cpu(), dim=1))

            if count == 0:
                raise ValueError("evaluation loader yielded no samples")
            probs = torch.cat(all_probs).numpy()

        metrics = np.array((n_correct / count, loss_sum / count))
        return probs, metrics

### Prototype classifier: SimCLR-trained `SupConResNet` (CIFAR-10)

In [7]:
# SimCLR checkpoint for a CIFAR-10 ResNet-18. Per the directory name: lr 0.05,
# weight decay 1e-4, batch 256, temperature 0.07, cosine schedule, seed 0.
# The file name suggests epoch 38.
fpath = (
    "../SupContrast/save/SimCLR/cifar10_models/"
    "cifar10_resnet18_lr_0.05_decay_0.0001_bsz_256_temp_0.07_cosine_seed_0/"
    "model_38.pth"
)

In [8]:
# map_location="cpu" lets the checkpoint load even on a machine without the GPU
# it was saved from. load_state_dict later copies the tensors onto whatever
# device the model's parameters live on.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
ckpt = torch.load(fpath, map_location="cpu")

In [10]:
# Build the SupCon network and load the checkpoint weights on every device,
# CPU included.
# The checkpoint keys are loaded as-is only when the encoder is wrapped in
# DataParallel, so they presumably carry its "module." prefix. Strip that
# prefix whenever the encoder is NOT wrapped.
model = SupConResNet(feat_dim=128)
state_dict = ckpt['model']
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    model.encoder = torch.nn.DataParallel(model.encoder)
else:
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
if torch.cuda.is_available():
    model = model.cuda()
model.load_state_dict(state_dict);

In [11]:
from utils.data import Cifar10Dataset, Cifar100Dataset

def fetch_dataset(name, tasks=None):
    """Return the project dataset wrapper for ``name``.

    Args:
        name: ``"cifar10"`` or ``"cifar100"``.
        tasks: class ids to include; defaults to every class of the dataset.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if name == "cifar10":
        return Cifar10Dataset(list(range(10)) if tasks is None else tasks, permute=False)
    if name == "cifar100":
        return Cifar100Dataset(list(range(100)) if tasks is None else tasks, permute=False)
    raise ValueError(f"unknown dataset {name!r}; expected 'cifar10' or 'cifar100'")


# All ten CIFAR-10 classes, loaded in a fixed (unshuffled) order.
# NOTE(review): 256 and 10 are presumably batch size and worker count; confirm
# against utils.data.
dataset = fetch_dataset(name="cifar10")
loaders = dataset.fetch_data_loaders(256, 10, shuf=False)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Nearest-prototype evaluation of the loaded model (default dot-product scoring).
runner = LinearProtoMap(dataloaders=loaders, net=model)
probs, metrics = runner.get_prob(proto=True)

In [13]:
# [accuracy, mean cross-entropy loss] as returned by LinearProtoMap.get_prob
metrics

array([0.51258   , 2.24490126])

In [18]:
# Width of the projection head's final bias.
# NOTE(review): the recorded output below ("256") is not a torch.Size repr and
# looks stale. Re-run this cell.
ckpt["model"]["head.2.bias"].size()

256

### Prototype classifier: supervised cross-entropy (SupCE) ResNet

In [6]:
# Supervised cross-entropy checkpoint for a CIFAR-10 ResNet-18. Per the
# directory name: lr 0.2, weight decay 1e-4, batch 256, cosine schedule,
# seed 0. The file name suggests epoch 42.
fpath = (
    "../SupContrast/save/SupCE/cifar10_models/"
    "cifar10_resnet18_lr_0.2_decay_0.0001_bsz_256_cosine_seed_0/"
    "model_42.pth"
)

In [7]:
# map_location="cpu" lets the checkpoint load even on a machine without the GPU
# it was saved from. load_state_dict later copies the tensors onto whatever
# device the model's parameters live on.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
ckpt = torch.load(fpath, map_location="cpu")

In [8]:
# Build the supervised-CE network and load the checkpoint weights on every
# device, CPU included.
# The checkpoint keys are loaded as-is only when the encoder is wrapped in
# DataParallel, so they presumably carry its "module." prefix. Strip that
# prefix whenever the encoder is NOT wrapped.
model = SupCEResNet(num_classes=10)
state_dict = ckpt['model']
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    model.encoder = torch.nn.DataParallel(model.encoder)
else:
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
if torch.cuda.is_available():
    model = model.cuda()
model.load_state_dict(state_dict);

In [15]:
from utils.data import Cifar10Dataset, Cifar100Dataset

def fetch_dataset(name, tasks=None):
    """Return the project dataset wrapper for ``name``.

    Args:
        name: ``"cifar10"`` or ``"cifar100"``.
        tasks: class ids to include; defaults to every class of the dataset.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    if name == "cifar10":
        return Cifar10Dataset(list(range(10)) if tasks is None else tasks, permute=False)
    if name == "cifar100":
        return Cifar100Dataset(list(range(100)) if tasks is None else tasks, permute=False)
    raise ValueError(f"unknown dataset {name!r}; expected 'cifar10' or 'cifar100'")


# Restrict CIFAR-100 to its last five classes (ids 95-99), in fixed order.
dataset = fetch_dataset("cifar100", tasks=list(range(95, 100)))
loaders = dataset.fetch_data_loaders(256, 10, shuf=False)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
# Five classes, matching the five CIFAR-100 ids selected for `dataset`.
runner = LinearProtoMap(dataloaders=loaders, net=model, n_cls=5)
probs, metrics = runner.get_prob(proto=True)
metrics

array([0.104    , 4.6391577])

10