In [None]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as nnf
import numpy as np
import pandas as pd
import timm
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset
from torchvision import transforms as T
from torch.nn.functional import interpolate

In [None]:
# Load CIFAR-10 with T.ToTensor(), so every image is a float tensor, and build a
# ResNet-18 with a 2-way head for the cat-vs-dog task.
cifar10_train = datasets.CIFAR10("../data", train=True, download=True, transform=T.ToTensor())
cifar10_test = datasets.CIFAR10("../data", train=False, download=True, transform=T.ToTensor())

my_model = timm.create_model('resnet18', pretrained=False)
num_ftrs = my_model.fc.in_features
my_model.fc = nn.Linear(num_ftrs, 2)
# Move to the GPU only after swapping the head: a layer assigned after .cuda()
# stays on the CPU and the forward pass then fails with a device mismatch.
my_model = my_model.cuda()

# CIFAR-10 label indices used for the binary task.
CAT_LABEL, DOG_LABEL = 3, 5


def class_images(dataset, label, limit=None):
    """Stack the images of ``dataset`` whose target equals ``label``.

    Matching indices are found from ``dataset.targets`` first, so only the
    selected samples are loaded and transformed.

    Args:
        dataset: dataset exposing ``targets`` (one label per sample) whose items
            are ``(image_tensor, label)`` pairs, e.g. torchvision CIFAR10 with
            ``T.ToTensor()``.
        label: class index to keep.
        limit: keep only the first ``limit`` matches in dataset order;
            ``None`` keeps all of them.

    Returns:
        Tensor of the selected images stacked along a new dim 0, in dataset order.
    """
    indices = [i for i, target in enumerate(dataset.targets) if target == label]
    if limit is not None:
        indices = indices[:limit]
    return torch.stack([dataset[i][0] for i in indices])


train_cats = class_images(cifar10_train, CAT_LABEL)
# Only the first 1000 training dogs are kept (all training cats are kept),
# which makes the training set class-imbalanced.
train_dogs = class_images(cifar10_train, DOG_LABEL, limit=1000)
test_cats = class_images(cifar10_test, CAT_LABEL)
test_dogs = class_images(cifar10_test, DOG_LABEL)

In [None]:
class BinaryDataset(Dataset):
    """Two-class dataset built from two stacked image tensors.

    Samples from ``class1`` get label 0 and samples from ``class2`` get label 1,
    in that order.

    Args:
        class1: images for label 0, stacked along dim 0.
        class2: images for label 1, stacked along dim 0 (same per-image shape
            as ``class1`` so the two can be concatenated).
        transform: optional callable applied to each image on access.
        target_transform: optional callable applied to each integer label.
    """

    def __init__(self, class1, class2, transform=None, target_transform=None):
        self.imgs = torch.cat([class1, class2])
        self.img_labels = [0] * len(class1) + [1] * len(class2)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Start from the raw values so a missing transform means "identity"
        # instead of leaving `image` / `label` unbound.
        image = self.imgs[idx]
        label = self.img_labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label
    
# Standard ImageNet per-channel mean/std, reshaped to (3, 1, 1) so they
# broadcast over a (C, H, W) image.
mean = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1).cuda()
std = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1).cuda()
# BinaryDataset stores images that are already tensors (CIFAR-10 is loaded with
# T.ToTensor()), so there is no ToTensor step here: T.ToTensor rejects tensor
# input with a TypeError. Normalize receives plain CPU floats (the same float32
# values as mean/std), so the per-sample transform does not touch the GPU.
transform = T.Compose([
    T.Resize([224, 224]),
    T.Normalize(mean.flatten().tolist(), std.flatten().tolist()),
])

def resample(sample, factor):
    """Randomly up- or down-sample a stack of examples along dim 0.

    Args:
        sample: tensor whose first dimension indexes examples.
        factor: target size relative to ``n = len(sample)``.
            ``factor > 1``: keep every original example (first, in order) and
            append ``int((factor - 1) * n)`` extra rows drawn uniformly *with*
            replacement (random oversampling).
            ``factor <= 1``: return ``int(factor * n)`` distinct rows drawn
            uniformly *without* replacement, in random order
            (``factor == 1`` yields a random permutation).

    Returns:
        A new tensor with the same trailing shape, dtype and device as ``sample``.

    Uses numpy's global RNG, so ``np.random.seed`` controls reproducibility.
    """
    n = len(sample)
    if factor > 1:  # upsample: keep all originals, append random duplicates
        extra = int((factor - 1) * n)
        if extra == 0:
            return sample.clone()
        picks = torch.as_tensor(np.random.randint(n, size=extra), dtype=torch.long)
        return torch.cat([sample, sample[picks]])
    # downsample: a uniform random subset without replacement
    remain = max(int(factor * n), 0)
    picks = torch.as_tensor(np.random.choice(n, size=remain, replace=False), dtype=torch.long)
    return sample[picks]

# Cats are passed first (label 0), dogs second (label 1).
# The training data is imbalanced; two rebalancing variants are available by
# swapping the train_ds line for one of these:
#   downsample cats to 20%: BinaryDataset(resample(train_cats, 0.2), train_dogs, transform=transform)
#   upsample dogs 5x:       BinaryDataset(train_cats, resample(train_dogs, 5), transform=transform)
train_ds = BinaryDataset(train_cats, train_dogs, transform=transform)
test_ds = BinaryDataset(test_cats, test_dogs, transform=transform)

train_loader = DataLoader(train_ds, batch_size=100, shuffle=True)
# Evaluation order is kept fixed.
test_loader = DataLoader(test_ds, batch_size=100, shuffle=False)

In [None]:
def train_model(model, loader, num_epochs=20, lr=4e-3, class_weights=None):
    """Train ``model`` in place with Adam and cross-entropy, then return it.

    After each epoch it prints the number of correctly classified training
    samples and the sum of the per-batch mean losses.

    Args:
        model: classifier whose output is per-class logits, (batch, num_classes).
        loader: iterable of ``(X, y)`` batches with integer class targets.
        num_epochs: number of passes over ``loader``.
        lr: Adam learning rate.
        class_weights: optional per-class loss weights, e.g. ``[1, 5]`` to
            up-weight class 1; ``None`` gives the unweighted loss.

    Returns:
        The same ``model`` object, after training.
    """
    # Send batches and loss weights to the device holding the model parameters
    # rather than hard-coding .cuda().
    device = next(model.parameters()).device
    weight = None
    if class_weights is not None:
        # CrossEntropyLoss needs a floating-point weight on the logits' device;
        # an integer CPU tensor such as torch.tensor([1, 5]) would raise.
        weight = torch.as_tensor(class_weights, dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_module = nn.CrossEntropyLoss(weight=weight)

    # Ensure training-mode behaviour (e.g. batch-norm statistics updates)
    # even if the model was left in eval mode by an earlier cell.
    model.train()
    for epoch in range(num_epochs):
        print("at epoch ", epoch)
        correct = 0
        total = 0
        total_loss = 0.0
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = loss_module(logits, y)  # already a scalar (mean reduction)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (logits.argmax(dim=1) == y).sum()
            total += len(X)
        # int() works whether `correct` is a tensor or still 0 (empty loader).
        print(int(correct), " out of  ", total)
        print("loss is ", total_loss)
    return model