<a href="https://colab.research.google.com/github/asharakeh/ot-4-ml-reading-group/blob/main/jumbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Dependencies:
- python (3.8.0)
- numpy (1.19.2)
- torch (1.7.1)
- POT (0.7.0)
- Cuda

command:
python3 train.py

Author : Kilian Fatras (kilian.fatras@mila.quebec)
"""


import torch
import torch.nn.functional as F
import torch.utils.data
import random
import numpy as np
import torch.nn as nn
import itertools

!pip install POT==0.7.0
import ot

import torch.utils.data
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data.sampler import BatchSampler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')




# Some util functions to evaluate the networks or to make stratified source minibatches


In [None]:
#-------- Eval function --------

def model_eval(dataloader, model_g, model_f):
    """
    Model evaluation function: top-1 accuracy of model_f(model_g(x)).

    Both networks are switched to eval mode and stay in eval mode on return;
    callers that keep training must call ``.train()`` again.

    args:
    - dataloader : iterable yielding (img, label) batches
    - model_g : feature extractor (torch.nn)
    - model_f : classification layer (torch.nn)

    return:
    - accuracy : fraction of correctly classified samples (float),
      0.0 when the dataloader yields no sample
    """
    model_g.eval()
    model_f.eval()
    # Run on the device that actually holds the feature extractor so inputs
    # always match the weights; parameter-free models fall back to the
    # module-level `device`.
    try:
        eval_device = next(model_g.parameters()).device
    except StopIteration:
        eval_device = device
    total_samples = 0
    correct_prediction = 0
    with torch.no_grad():
        for img, label in dataloader:
            img = img.to(eval_device)
            label = label.long().to(eval_device)
            logits = model_f(model_g(img))
            # softmax is monotonic, so argmax over the logits gives the same
            # predicted class as argmax over the probabilities.
            correct_prediction += (torch.argmax(logits, 1) == label).sum().item()
            total_samples += logits.size(0)
    if total_samples == 0:
        # Nothing was evaluated: avoid dividing by zero.
        return 0.0
    return correct_prediction / total_samples



#--------SAMPLER-------

class BalancedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Batch sampler drawing the same number of samples from every class.

    Each batch holds batch_size // n_classes indices per class, i.e.
    n_classes * (batch_size // n_classes) indices in total, in shuffled order.
    Taken from https://github.com/criteo-research/pytorch-ada/blob/master/adalib/ada/datasets/sampler.py
    """

    def __init__(self, labels, batch_size):
        # `labels` must provide `.numpy()` (e.g. a CPU torch tensor).
        classes = sorted(set(labels.numpy()))
        print(classes)

        n_classes = len(classes)
        per_class = batch_size // n_classes
        self._n_samples = per_class
        if not per_class:
            raise ValueError(
                f"batch_size should be bigger than the number of classes, got {batch_size}"
            )

        # One slice iterator per class, over the positions of that class.
        self._class_iters = []
        for class_ in classes:
            positions = np.where(labels == class_)[0]
            self._class_iters.append(InfiniteSliceIterator(positions, class_=class_))

        # Effective batch size once rounded down to a multiple of n_classes.
        effective_size = per_class * n_classes
        self.n_dataset = len(labels)
        self._n_batches = self.n_dataset // effective_size
        if not self._n_batches:
            raise ValueError(
                f"Dataset is not big enough to generate batches with size {effective_size}"
            )
        print("K=", n_classes, "nk=", per_class)
        print("Batch size = ", effective_size)

    def __iter__(self):
        for _ in range(self._n_batches):
            batch = [
                index
                for class_iter in self._class_iters
                for index in class_iter.get(self._n_samples)
            ]
            # Mix the classes inside the batch.
            np.random.shuffle(batch)
            yield batch

        # Rewind every class iterator once the epoch is over.
        for class_iter in self._class_iters:
            class_iter.reset()

    def __len__(self):
        return self._n_batches
    
    
class InfiniteSliceIterator:
    """
    Serve consecutive slices of a numpy array without ever running out.

    The array is shuffled in place (np.random) each time the cursor is at
    position 0, and the cursor rewinds when the remaining tail is too short.
    """

    def __init__(self, array, class_):
        assert type(array) is np.ndarray
        self.class_ = class_
        self.i = 0
        self.array = array

    def reset(self):
        """Move the cursor back to the start of the array."""
        self.i = 0

    def get(self, n):
        """Return n items; repeats items when the array holds fewer than n."""
        size = len(self.array)

        # The whole array is smaller than the request: tile it.
        if size < n:
            print(f"there are really few items in class {self.class_}")
            self.reset()
            np.random.shuffle(self.array)
            repeats, remainder = divmod(n, size)
            tiled = np.tile(self.array, repeats)
            return np.concatenate((tiled, self.array[:remainder]))

        # The unread tail is smaller than the request: start over.
        remaining = size - self.i
        if remaining < n:
            self.reset()

        # A fresh pass over the array starts with a reshuffle.
        if self.i == 0:
            np.random.shuffle(self.array)
        start = self.i
        stop = start + n
        self.i = stop
        return self.array[start:stop]

# Define the models we will use

In [None]:
   
    
class Classifier2(nn.Module):
    ''' Classifier class: one linear layer mapping 128 features to nclass logits '''
    def __init__(self, nclass=None):
        super(Classifier2, self).__init__()
        # nclass has no usable default; fail with an explicit error (asserts
        # are stripped under -O).
        if nclass is None:
            raise ValueError("nclass must be provided")
        # The output size follows nclass instead of a hard-coded 10.
        self.fc2 = nn.Linear(128, nclass)

    def forward(self, x):
        ''' Return the (unnormalized) class scores for the features x '''
        x = self.fc2(x)
        return x
    
    
def weights_init(m):
    '''
    Weight init function for layers (meant for ``module.apply``).

    The layer kind is picked from the class name:
    - 'Conv'      : weight ~ N(0, 0.02), bias untouched
    - 'BatchNorm' : weight ~ N(1, 0.02), bias = 0
    - 'Linear'    : weight ~ N(0, 0.1), bias = 0
    Any other module is left unchanged. Missing weight/bias tensors (e.g.
    ``bias=False`` or ``affine=False``) are skipped instead of raising.
    '''
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.1)
        if bias is not None:
            bias.data.fill_(0)
        
        
def call_bn(bn, x):
    ''' Apply the batch norm layer `bn` to `x` and return its output '''
    normalized = bn(x)
    return normalized


class Cnn_generator(nn.Module):
    '''
    CNN feature extractor class.

    Three stages of two (3x3 conv + batch norm + ReLU) layers, each stage
    followed by a 2x2 max pooling, then a linear layer with a sigmoid that
    outputs 128 features per sample.

    args:
    - input_channel : number of channels of the input images (int)
    - n_outputs : accepted for compatibility, not used by the network
    - dropout_rate : rate of the registered Dropout2d layer (the layer is
      created but not applied in forward)
    - momentum : momentum of the batch norm running statistics (float)
    '''
    def __init__(self, input_channel=1, n_outputs=10, dropout_rate=0.25, momentum=0.1):
        # nn.Module must be initialised before attributes are set on it.
        super(Cnn_generator, self).__init__()
        self.momentum = momentum
        self.c1 = nn.Conv2d(input_channel, 32, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.c4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.c5 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.c6 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        # 128 channels of 3x3 maps: three stride-2 poolings take a 28x28
        # input down to 3x3 (28 -> 14 -> 7 -> 3).
        self.linear1 = nn.Linear(128 * 3 * 3, 128)
        # The momentum argument is forwarded to every batch norm layer (the
        # default 0.1 is also BatchNorm2d's own default).
        self.bn1 = nn.BatchNorm2d(32, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(32, momentum=momentum)
        self.bn3 = nn.BatchNorm2d(64, momentum=momentum)
        self.bn4 = nn.BatchNorm2d(64, momentum=momentum)
        self.bn5 = nn.BatchNorm2d(128, momentum=momentum)
        self.bn6 = nn.BatchNorm2d(128, momentum=momentum)
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, x):
        ''' Return the sigmoid-activated 128-d features of the images x '''
        h = F.relu(self.bn1(self.c1(x)))
        h = F.relu(self.bn2(self.c2(h)))
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.bn3(self.c3(h)))
        h = F.relu(self.bn4(self.c4(h)))
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.bn5(self.c5(h)))
        h = F.relu(self.bn6(self.c6(h)))
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        # Flatten everything but the batch dimension.
        h = h.view(h.size(0), -1)
        logit = torch.sigmoid(self.linear1(h))
        return logit

# Code the Jumbot and source only methods. Fill the blank !

In [None]:
class Jumbot(object):
    """
    Jumbot class.

    Trains a feature extractor and a classifier either on labelled source
    data only (``source_only``) or with the jumbot method (``fit``), which adds
    an unbalanced optimal transport loss between source and target minibatches.
    """
    def __init__(self, model_g, model_f, n_class, eta1=0.001, eta2=0.0001, tau=1., epsilon=0.1):
        """
        Initialize jumbot method.
        args :
        - model_g : feature extractor (torch.nn)
        - model_f : classification layer (torch.nn)
        - n_class : number of classes (int)
        - eta_1 : feature comparison coefficient (float)
        - eta_2 : label comparison coefficient (float)
        - tau : marginal coefficient (float)
        - epsilon : entropic regularization (float)
        """
        self.model_g = model_g
        self.model_f = model_f
        self.n_class = n_class
        self.eta1 = eta1  # weight of the feature term of the ground cost
        self.eta2 = eta2  # weight of the label term of the ground cost
        self.tau = tau
        self.epsilon = epsilon
        print('eta1, eta2, tau, epsilon: ', self.eta1, self.eta2, self.tau, self.epsilon)

    def _device(self):
        """Return the device holding model_g's parameters (CPU if it has none)."""
        try:
            return next(self.model_g.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    @staticmethod
    def _cycle(loader):
        """Yield batches from loader forever, re-iterating it on exhaustion."""
        # Unlike itertools.cycle, this does not keep a copy of every batch and
        # lets a shuffling loader draw a new order on each pass.
        while True:
            empty = True
            for batch in loader:
                empty = False
                yield batch
            if empty:
                raise ValueError("target_loader yielded no batches")

    def fit(self, source_loader, target_loader, test_loader, n_epochs, criterion=nn.CrossEntropyLoss()):
        """
        Run jumbot method.
        args :
        - source_loader : source dataset
        - target_loader : target dataset
        - test_loader : test dataset
        - n_epochs : number of epochs (int)
        - criterion : source loss (nn)

        return:
        - total loss of the last minibatch (None if no minibatch was seen)
        """
        dev = self._device()
        target_batches = self._cycle(target_loader)
        optimizer_g = torch.optim.Adam(self.model_g.parameters(), lr=2e-4)
        optimizer_f = torch.optim.Adam(self.model_f.parameters(), lr=2e-4)
        s_loss = da_loss = tot_loss = None

        for id_epoch in range(n_epochs):
            self.model_g.train()
            self.model_f.train()
            for xs_mb, ys in source_loader:
                ### Load data
                xs_mb, ys = xs_mb.to(dev), ys.to(dev)
                xt_mb, _ = next(target_batches)
                xt_mb = xt_mb.to(dev)

                g_xs_mb = self.model_g(xs_mb)
                f_g_xs_mb = self.model_f(g_xs_mb)
                g_xt_mb = self.model_g(xt_mb)
                f_g_xt_mb = self.model_f(g_xt_mb)
                # log-probabilities of the target predictions; log_softmax
                # avoids log(0) when a probability underflows.
                log_pred_xt = F.log_softmax(f_g_xt_mb, 1)

                ### loss
                s_loss = criterion(f_g_xs_mb, ys)

                ###  Ground cost
                # feature term: squared euclidean distance between embeddings
                embed_cost = torch.cdist(g_xs_mb, g_xt_mb) ** 2

                # label term: cross entropy between source labels and target
                # predictions, for every (source, target) pair
                ys_onehot = F.one_hot(ys, num_classes=self.n_class).float()
                t_cost = -torch.mm(ys_onehot, log_pred_xt.t())

                total_cost = self.eta1 * embed_cost + self.eta2 * t_cost

                # OT computation: entropic unbalanced OT between uniform
                # minibatch weights; the plan is computed on a detached cost
                # and treated as a constant in the loss.
                a = ot.unif(g_xs_mb.size(0))
                b = ot.unif(g_xt_mb.size(0))
                pi = ot.unbalanced.sinkhorn_knopp_unbalanced(
                    a, b, total_cost.detach().cpu().numpy(), self.epsilon, self.tau)
                pi = torch.from_numpy(pi).float().to(dev)

                # train the model
                optimizer_g.zero_grad()
                optimizer_f.zero_grad()

                # transport cost under the plan; gradients flow through the cost
                da_loss = torch.sum(pi * total_cost)
                tot_loss = s_loss + da_loss
                tot_loss.backward()

                optimizer_g.step()
                optimizer_f.step()

            if tot_loss is not None:
                print('epoch, loss : ', id_epoch, s_loss.item(), da_loss.item())
            if id_epoch%10 == 0:
                source_acc = self.evaluate(source_loader)
                target_acc = self.evaluate(test_loader)
                print('source and test accuracies : ', source_acc, target_acc)

        return tot_loss

    def source_only(self, source_loader, criterion=nn.CrossEntropyLoss(), n_epochs=10):
        """
        Run source only.
        args :
        - source_loader : source dataset
        - criterion : source loss (nn)
        - n_epochs : number of epochs (int)

        return:
        - loss of the last minibatch (None if no minibatch was seen)
        """
        dev = self._device()
        optimizer_g = torch.optim.Adam(self.model_g.parameters(), lr=2e-4)
        optimizer_f = torch.optim.Adam(self.model_f.parameters(), lr=2e-4)
        tot_loss = None

        for id_epoch in range(n_epochs):
            self.model_g.train()
            self.model_f.train()
            for xs_mb, ys in source_loader:
                ### Load data
                xs_mb, ys = xs_mb.to(dev), ys.to(dev)

                g_xs_mb = self.model_g(xs_mb)
                f_g_xs_mb = self.model_f(g_xs_mb)

                ### loss
                s_loss = criterion(f_g_xs_mb, ys)

                # train the model
                optimizer_g.zero_grad()
                optimizer_f.zero_grad()

                tot_loss = s_loss
                tot_loss.backward()

                optimizer_g.step()
                optimizer_f.step()

        return tot_loss


    def evaluate(self, data_loader):
        """Return the accuracy computed by model_eval on data_loader."""
        score = model_eval(data_loader, self.model_g, self.model_f)
        return score


# Create the datasets

In [None]:
batch_size = 500
nclass = 10
np.random.seed(1980)

# pre-processing: resize to 28x28, convert to tensor, then normalize with
# mean 0.5 and std 0.5


######DATASETS
### TRAIN sets
transform_usps = transforms.Compose([
                    transforms.Resize(28),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ])

train_usps_trainset = datasets.USPS('./data', train=True, download=True,
                                transform=transform_usps)

print('nb source data : ', len(train_usps_trainset))

# Materialize the transformed source images and their labels; the labels
# drive the class-balanced batch sampler below.
source_data = torch.zeros((len(train_usps_trainset), 1, 28, 28))
source_labels = torch.zeros((len(train_usps_trainset)))

for i, data in enumerate(train_usps_trainset):
    source_data[i] = data[0]
    source_labels[i] = data[1]

train_batch_sampler = BalancedBatchSampler(source_labels, batch_size=batch_size)
train_usps_loader = torch.utils.data.DataLoader(train_usps_trainset, batch_sampler=train_batch_sampler)

# Per-channel statistics are passed as 1-tuples, as for USPS above
# ((0.5) is just the float 0.5, not a tuple).
transform_mnist = transforms.Compose([
                    transforms.Resize(28),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,))
                ])

train_mnist_trainset = datasets.MNIST('./data', train=True, download=True,
                                    transform=transform_mnist)
train_mnist_loader = torch.utils.data.DataLoader(train_mnist_trainset, batch_size=batch_size, shuffle=True)


### TEST sets

test_usps_loader = torch.utils.data.DataLoader(
        datasets.USPS('./data', train=False, transform=transform_usps, download=True),
        batch_size=batch_size, shuffle=False)

test_mnist_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, download=True, transform=transform_mnist),
        batch_size=batch_size, shuffle=False)

nb source data :  7291
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
K= 10 nk= 50
Batch size =  500


# Train the network on the source only and compare the performances

In [None]:
# JUMBOT hyper-parameters (printed by the Jumbot constructor)
eta1 = 0.1
eta2 = 0.1
tau = 1.0
epsilon = 0.1

# Use the module-level `device` (CUDA when available, else CPU) instead of a
# hard-coded .cuda().
model_g = Cnn_generator().to(device).apply(weights_init)
model_f = Classifier2(nclass=nclass).to(device).apply(weights_init)

model_g.train()
model_f.train()

# Baseline: train on the labelled source (USPS) data only.
jumbot = Jumbot(model_g, model_f, n_class=nclass, eta1=eta1, eta2=eta2, tau=tau, epsilon=epsilon)
loss = jumbot.source_only(train_usps_loader, n_epochs=50)

source_acc =jumbot.evaluate(test_usps_loader)
target_acc =jumbot.evaluate(test_mnist_loader)

print("source_acc = {}, target_acc ={}".format(source_acc, target_acc))

eta1, eta2, tau, epsilon:  0.1 0.1 1.0 0.1
source_acc = 0.9735924265072247, target_acc =0.2182


# Train the network with JUMBOT

In [None]:
# Fresh networks, placed on the module-level `device` (CUDA when available,
# else CPU) instead of a hard-coded .cuda().
model_g = Cnn_generator().to(device).apply(weights_init)
model_f = Classifier2(nclass=nclass).to(device).apply(weights_init)

model_g.train()
model_f.train()

jumbot = Jumbot(model_g, model_f, n_class=nclass, eta1=eta1, eta2=eta2, tau=tau, epsilon=epsilon)
# Train on the source only first (default number of epochs), then run JUMBOT.
loss = jumbot.source_only(train_usps_loader)
loss = jumbot.fit(train_usps_loader, train_mnist_loader, test_mnist_loader, n_epochs=100)

source_acc =jumbot.evaluate(test_usps_loader)
target_acc =jumbot.evaluate(test_mnist_loader)

print("source_acc = {}, target_acc ={}".format(source_acc, target_acc))

eta1, eta2, tau, epsilon:  0.1 0.1 1.0 0.1
epoch, loss :  0 0.15685325860977173 0.6192663908004761
source and test accuracies :  0.9775714285714285 0.8857
epoch, loss :  1 0.11291591078042984 0.40434613823890686
epoch, loss :  2 0.10371321439743042 0.28302329778671265
epoch, loss :  3 0.07542509585618973 0.22587323188781738
epoch, loss :  4 0.06522281467914581 0.20267876982688904
epoch, loss :  5 0.05788493528962135 0.15609566867351532
epoch, loss :  6 0.0433298796415329 0.1373007893562317
epoch, loss :  7 0.04535941779613495 0.1276775747537613
epoch, loss :  8 0.03422129154205322 0.10809904336929321
epoch, loss :  9 0.030514037236571312 0.09756715595722198
epoch, loss :  10 0.033875081688165665 0.09941273182630539
source and test accuracies :  0.968 0.9629
epoch, loss :  11 0.03565165773034096 0.11504471302032471
epoch, loss :  12 0.022952986881136894 0.0845676138997078
epoch, loss :  13 0.01983393356204033 0.08031203597784042
epoch, loss :  14 0.023547692224383354 0.0669655129313469
