In [1]:
import time
import numpy as np
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from tqdm.notebook import tqdm

In [2]:
from FADAnet.FADAloader import *
from FADAnet.FADAmodule import *

In [3]:
n_epoch = 81  # total number of epoch
n_epoch_pt = 0
batch_size = 256
batch_size_test = 512
lr = 0.001
gamma = 0.0005
n_support = 7

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
np.random.seed(777)
torch.manual_seed(777)
torch.cuda.manual_seed(777)

In [4]:
train_dataloader=mnist_dataloader(batch_size=batch_size,train=True)
test_dataloader=mnist_dataloader(train=False)
real_test_loader = svhn_dataloader(batch_size = batch_size_test, train=False)

classifier=Classifier()
encoder=Encoder()

classifier.to(device)
encoder.to(device)
loss_fn1 = torch.nn.CrossEntropyLoss()
loss_fn2 = torch.nn.CosineEmbeddingLoss()
optimizer=torch.optim.Adam(list(encoder.parameters())+list(classifier.parameters()))

X_t, Y_t = create_target_samples(n=n_support)

Using downloaded and verified file: ./data/SVHN/test_32x32.mat


In [5]:
# for epoch in tqdm(range(n_epoch)):
for epoch in range(n_epoch):
    
    for data,labels in train_dataloader:
        data=data.to(device)
        labels=labels.to(device)
        X_t = X_t.to(device)
        Y_t = Y_t.to(device)
        optimizer.zero_grad()
        
        map_s = encoder(data)
        y_pred=classifier(map_s)
        loss1=loss_fn1(y_pred,labels)
        map_t = encoder(X_t)
        
        loss2 = 0
        means_s = []
        # means_t = []
        for num in range(10):
#             subset = map_t[Y_t == num]
#             means_t.append(torch.mean(subset, dim = 0))
            subset = map_s[labels == num]
            means_s.append(torch.mean(subset, dim = 0))
        for ctr in range(10*n_support):
            num = Y_t[ctr]
            tmp = map_t[ctr] - means_s[num]
            dd = torch.stack([tmp]*9)
            tmp = [(means_s[num] - means_s[i]) for i in range(10) if i != num]
            Cplane = torch.stack(tmp)
            loss2 += loss_fn2(-dd, Cplane, torch.FloatTensor([-1]*9).to(device))
            loss2 += loss_fn2(dd, Cplane, torch.FloatTensor([-1]*9).to(device))
        loss = loss1 + gamma*loss2
        loss.backward()
        optimizer.step()
    if epoch%5 == 0: print("loss1:", loss1.item(), "   loss2:", loss2.item())
        
    if epoch%10 == 0:
        acc=0
        for data,labels in test_dataloader:
            data=data.to(device)
            labels=labels.to(device)
            y_test_pred=classifier(encoder(data))
            acc+=(torch.max(y_test_pred,1)[1]==labels).float().mean().item()
        accuracy=round(acc / float(len(test_dataloader)), 3)
        print("On source domain: Epoch %d/%d  accuracy: %.3f "%(epoch+1,n_epoch,accuracy))
    
    if epoch%20 == 0:
        mapset_f = []
        labelset_f = []
        # count_map = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}
        for data, labels in train_dataloader:
            data = data.to(device)
            labels=labels.to(device)
            map_f = encoder(data)
            mapset_f.append(map_f)
            labelset_f.append(labels)
        map_f = torch.cat(mapset_f[:-1])
        label_f = torch.cat(labelset_f[:-1])

        means_f = []
        for num in range(10):
            subset = map_f[label_f == num]
            means_f.append(torch.mean(subset, dim = 0))
        nume = 0
        deno = 0
        for data, labels in real_test_loader:
            data = data.to(device)
            labels = labels.to(device)
            tmp = encoder(data)

            map_ff = encoder(data)
            for ctr in range(batch_size_test):
#                 num = labels[ctr]
#                 tmp = map_ff[ctr] - means_f[num]
#                 dd_f = torch.stack([tmp]*9)
#                 tmp = [(means_f[num] - means_f[i]) for i in range(10) if i != num]
#                 Cplane_f = torch.stack(tmp)
#                 loss_real = loss_fn2(-dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device)) + loss_fn2(dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device))

                others = []
                for j in range(10):
                    num = j
                    tmp = map_ff[ctr] - means_f[num]
                    dd_f = torch.stack([tmp]*9)
                    tmp = [(means_f[num] - means_f[i]) for i in range(10) if i != num]
                    Cplane_f = torch.stack(tmp)
                    loss_f = loss_fn2(-dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device)) + loss_fn2(dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device))
                    others.append(loss_f.item())
                # print(min(others))
                if np.argmin(others) == labels[ctr]: 
                    nume+=1
                    # count_map[labels[ctr].item()] += 1
            deno += batch_size_test
            if deno > 5100: break
        print("-------------------------------------------------")
        print("On target domain: Epoch %d/%d  accuracy:"%(epoch+1,n_epoch), nume / deno)
        # print(count_map)
        print("-------------------------------------------------")
        
        

loss1: 1.6307697296142578    loss2: 28.64988899230957
On source domain: Epoch 1/81  accuracy: 0.855 
-------------------------------------------------
On target domain: Epoch 1/81  accuracy: 0.23203125
-------------------------------------------------
loss1: 1.5232528448104858    loss2: 7.257229804992676
loss1: 1.4714350700378418    loss2: 2.1498336791992188
On source domain: Epoch 11/81  accuracy: 0.985 
loss1: 1.4670677185058594    loss2: 1.4105409383773804
loss1: 1.4917410612106323    loss2: 0.9906170964241028
On source domain: Epoch 21/81  accuracy: 0.989 
-------------------------------------------------
On target domain: Epoch 21/81  accuracy: 0.2095703125
-------------------------------------------------
loss1: 1.4611505270004272    loss2: 0.5555048584938049
loss1: 1.46116304397583    loss2: 0.534331202507019
On source domain: Epoch 31/81  accuracy: 0.987 
loss1: 1.4647431373596191    loss2: 0.41368842124938965
loss1: 1.470369815826416    loss2: 0.3746226727962494
On source doma