In [16]:
import time
import numpy as np
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from tqdm.notebook import tqdm

In [17]:
from FADAnet.UMloader import *
from FADAnet.UMmodule import *

In [18]:
# loss version: loss = loss1 + gamma*loss2/(10*n_support) + theta*loss3/(10*n_support)
n_epoch = 161  # total number of epoch
n_epoch_pt = 0
batch_size = 128
batch_size_test = 256
lr = 0.0007
gamma = 0.065
theta = 0.055
n_support = 7
loss3_margin = 0.7

domain_adaptation_task = 'MNIST_to_USPS'
repetition = 0

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
np.random.seed(777)
torch.manual_seed(777)
torch.cuda.manual_seed(777)

In [19]:
train_set = TrainSet(domain_adaptation_task, repetition, n_support)
test_set = TestSet(domain_adaptation_task, repetition, n_support)

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
real_test_loader = DataLoader(test_set, batch_size=batch_size_test, shuffle=True, drop_last=True)

classifier=Classifier()
encoder=Encoder()
classifier.to(device)
encoder.to(device)

loss_fn1 = torch.nn.CrossEntropyLoss()
loss_fn2 = torch.nn.CosineEmbeddingLoss()
loss_fn3 = torch.nn.CosineEmbeddingLoss(margin=loss3_margin)
optimizer=torch.optim.Adadelta(list(encoder.parameters())+list(classifier.parameters()))

X_t = torch.from_numpy(train_set.x_target).unsqueeze(1)
Y_t = torch.from_numpy(train_set.y_target).long()

Source X :  2000  Y :  2000
Target X :  70  Y :  70
Class P :  14000  N :  126000


In [20]:
# for epoch in tqdm(range(n_epoch)):
for epoch in range(n_epoch):
    
    for data,labels in train_dataloader:
        data=data.to(device)
        labels=labels.to(device)
        # print(labels)
        X_t = X_t.to(device)
        Y_t = Y_t.to(device)
        optimizer.zero_grad()
        
        map_s = encoder(data)
        y_pred=classifier(map_s)
        loss1=loss_fn1(y_pred,labels)
        map_t = encoder(X_t)
        
        loss2 = 0
        loss3 = 0
        means_s = []
        # means_t = []
        
        for num in range(10):
#             subset = map_t[Y_t == num]
#             means_t.append(torch.mean(subset, dim = 0))
            subset = map_s[labels == num]
            if len(subset) <= 1: 
                danger = 1
            means_s.append(torch.mean(subset, dim = 0))
        for ctr in range(10*n_support):
            num = Y_t[ctr]
            dd = torch.stack([ map_t[ctr] - means_s[num] ]*10)
            Cplane = torch.stack( [(means_s[i-1] - means_s[i]) for i in range(10)] )
            loss2 += loss_fn2(-dd, Cplane, torch.FloatTensor([-1]*10).to(device))
            loss2 += loss_fn2(dd, Cplane, torch.FloatTensor([-1]*10).to(device))
            Cplane = torch.stack( [( map_t[ctr] - means_s[i] ) for i in range(10) if i != num] )
            loss3 += loss_fn3(dd[:-1], Cplane, torch.FloatTensor([-1]*9).to(device))
        if torch.isnan(loss1) or torch.isnan(loss2) or torch.isnan(loss3):
            continue
        loss = loss1 + gamma*loss2/(10*n_support) + theta*loss3/(10*n_support)
        loss.backward()
        optimizer.step()
        
    if epoch%5 == 0: print("loss1:", loss1.item(), "   loss2:", loss2.item(), "   loss3:", loss3.item())
        
    if epoch%10 == 0:
        acc=0
        for data,labels in test_dataloader:
            data=data.to(device)
            labels=labels.to(device)
            y_test_pred=classifier(encoder(data))
            acc+=(torch.max(y_test_pred,1)[1]==labels).float().mean().item()
        accuracy=round(acc / float(len(test_dataloader)), 3)
        print("On source domain: Epoch %d/%d  accuracy: %.3f "%(epoch+1,n_epoch,accuracy))
    
    if epoch%20 == 0:
        mapset_f = []
        labelset_f = []
        for data, labels in train_dataloader:
            data = data.to(device)
            labels=labels.to(device)
            map_f = encoder(data)
            mapset_f.append(map_f)
            labelset_f.append(labels)
        map_f = torch.cat(mapset_f[:-1])
        label_f = torch.cat(labelset_f[:-1])

        means_f = []
        for num in range(10):
            subset = map_f[label_f == num]
            means_f.append(torch.mean(subset, dim = 0))
#        nume = 0
        deno = 0
        acc = 0  
        for data, labels in real_test_loader:
            data = data.to(device)
            labels = labels.to(device)        
            map_ff = encoder(data)
            distTS = []
            for ii in range(10):
                distTS.append(torch.norm((map_ff - means_f[ii]), dim=1))
            distTS = torch.stack(distTS)
            acc+=torch.sum(torch.argmin(distTS, dim=0)==labels)
            
#             for ctr in range(batch_size_test):
#                 others = []
#                 for j in range(10):
#                     num = j
#                     tmp = map_ff[ctr] - means_f[num]
#                     dd_f = torch.stack([tmp]*9)
#                     tmp = [(means_f[num] - means_f[i]) for i in range(10) if i != num]
#                     Cplane_f = torch.stack(tmp)
#                     loss_f = loss_fn2(-dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device)) + loss_fn2(dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device))
#                     others.append(loss_f.item())
#                 # print(min(others))
#                 if np.argmin(others) == labels[ctr]: 
#                     nume+=1
            deno += len(labels)
        print("-------------------------------------------------")
#        print("On target domain: Epoch %d/%d  accuracy:"%(epoch+1,n_epoch), nume / deno)
        print("Another one on TD: Epoch %d/%d  accuracy:"%(epoch+1,n_epoch), acc.item() / deno)
        print("-------------------------------------------------")
        
        

loss1: 1.4972822666168213    loss2: 14.995403289794922    loss3: 0.0
On source domain: Epoch 1/161  accuracy: 0.952 
-------------------------------------------------
Another one on TD: Epoch 1/161  accuracy: 0.6194196428571429
-------------------------------------------------
loss1: 1.4717826843261719    loss2: 8.824933052062988    loss3: 0.07994669675827026
loss1: 1.4704762697219849    loss2: 5.719654083251953    loss3: 0.6001102328300476
On source domain: Epoch 11/161  accuracy: 0.995 
loss1: 1.4651024341583252    loss2: 3.956559419631958    loss3: 0.3133990168571472
loss1: 1.4766459465026855    loss2: 3.198007345199585    loss3: 0.47490960359573364
On source domain: Epoch 21/161  accuracy: 0.997 
-------------------------------------------------
Another one on TD: Epoch 21/161  accuracy: 0.8699776785714286
-------------------------------------------------
loss1: 1.461299180984497    loss2: 3.0450539588928223    loss3: 0.35706719756126404
loss1: 1.4689639806747437    loss2: 2.753540

In [22]:
# for epoch in tqdm(range(n_epoch)):
for epoch in range(101):
    
    for data,labels in train_dataloader:
        data=data.to(device)
        labels=labels.to(device)
        # print(labels)
        X_t = X_t.to(device)
        Y_t = Y_t.to(device)
        optimizer.zero_grad()
        
        map_s = encoder(data)
        y_pred=classifier(map_s)
        loss1=loss_fn1(y_pred,labels)
        map_t = encoder(X_t)
        
        loss2 = 0
        loss3 = 0
        means_s = []
        # means_t = []
        
        for num in range(10):
#             subset = map_t[Y_t == num]
#             means_t.append(torch.mean(subset, dim = 0))
            subset = map_s[labels == num]
            if len(subset) <= 1: 
                danger = 1
            means_s.append(torch.mean(subset, dim = 0))
        for ctr in range(10*n_support):
            num = Y_t[ctr]
            dd = torch.stack([ map_t[ctr] - means_s[num] ]*10)
            Cplane = torch.stack( [(means_s[i-1] - means_s[i]) for i in range(10)] )
            loss2 += loss_fn2(-dd, Cplane, torch.FloatTensor([-1]*10).to(device))
            loss2 += loss_fn2(dd, Cplane, torch.FloatTensor([-1]*10).to(device))
            Cplane = torch.stack( [( map_t[ctr] - means_s[i] ) for i in range(10) if i != num] )
            loss3 += loss_fn3(dd[:-1], Cplane, torch.FloatTensor([-1]*9).to(device))
        if torch.isnan(loss1) or torch.isnan(loss2) or torch.isnan(loss3):
            continue
        loss = loss1 + gamma*loss2/(10*n_support) + theta*loss3/(10*n_support)
        loss.backward()
        optimizer.step()
        
    if epoch%5 == 0: print("loss1:", loss1.item(), "   loss2:", loss2.item(), "   loss3:", loss3.item())
        
    if epoch%10 == 0:
        acc=0
        for data,labels in test_dataloader:
            data=data.to(device)
            labels=labels.to(device)
            y_test_pred=classifier(encoder(data))
            acc+=(torch.max(y_test_pred,1)[1]==labels).float().mean().item()
        accuracy=round(acc / float(len(test_dataloader)), 3)
        print("On source domain: Epoch %d/%d  accuracy: %.3f "%(epoch+1,n_epoch,accuracy))
    
    if epoch%20 == 0:
        mapset_f = []
        labelset_f = []
        for data, labels in train_dataloader:
            data = data.to(device)
            labels=labels.to(device)
            map_f = encoder(data)
            mapset_f.append(map_f)
            labelset_f.append(labels)
        map_f = torch.cat(mapset_f[:-1])
        label_f = torch.cat(labelset_f[:-1])

        means_f = []
        for num in range(10):
            subset = map_f[label_f == num]
            means_f.append(torch.mean(subset, dim = 0))
#        nume = 0
        deno = 0
        acc = 0  
        for data, labels in real_test_loader:
            data = data.to(device)
            labels = labels.to(device)        
            map_ff = encoder(data)
            distTS = []
            for ii in range(10):
                distTS.append(torch.norm((map_ff - means_f[ii]), dim=1))
            distTS = torch.stack(distTS)
            acc+=torch.sum(torch.argmin(distTS, dim=0)==labels)
            
#             for ctr in range(batch_size_test):
#                 others = []
#                 for j in range(10):
#                     num = j
#                     tmp = map_ff[ctr] - means_f[num]
#                     dd_f = torch.stack([tmp]*9)
#                     tmp = [(means_f[num] - means_f[i]) for i in range(10) if i != num]
#                     Cplane_f = torch.stack(tmp)
#                     loss_f = loss_fn2(-dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device)) + loss_fn2(dd_f, Cplane_f, torch.FloatTensor([-1]*9).to(device))
#                     others.append(loss_f.item())
#                 # print(min(others))
#                 if np.argmin(others) == labels[ctr]: 
#                     nume+=1
            deno += len(labels)
        print("-------------------------------------------------")
#        print("On target domain: Epoch %d/%d  accuracy:"%(epoch+1,n_epoch), nume / deno)
        print("Another one on TD: Epoch %d/%d  accuracy:"%(epoch+1,n_epoch), acc.item() / deno)
        print("-------------------------------------------------")
        
        

loss1: 1.461153268814087    loss2: 0.9428699612617493    loss3: 0.19058966636657715
loss1: 1.4611526727676392    loss2: 1.050830602645874    loss3: 0.13861379027366638
On source domain: Epoch 11/161  accuracy: 1.000 
loss1: 1.461152195930481    loss2: 1.033910870552063    loss3: 0.0010127954883500934
loss1: 1.4611756801605225    loss2: 0.8750697374343872    loss3: 0.08127082884311676
On source domain: Epoch 21/161  accuracy: 1.000 
-------------------------------------------------
Another one on TD: Epoch 21/161  accuracy: 0.93359375
-------------------------------------------------
loss1: 1.4611507654190063    loss2: 0.7994027137756348    loss3: 0.010811077430844307
loss1: 1.465362548828125    loss2: 1.006070852279663    loss3: 0.013372169807553291
On source domain: Epoch 31/161  accuracy: 1.000 
loss1: 1.4611573219299316    loss2: 1.1495729684829712    loss3: 0.253337025642395
loss1: 1.4612082242965698    loss2: 0.9498468637466431    loss3: 0.0390334390103817
On source domain: Epoch 

In [11]:
# result check
mapset = []
labelset = []
for data, labels in train_dataloader:
    data=data.to(device)
    fmap = encoder(data).cpu().detach().numpy()
    labels=labels.to(device).cpu().detach().numpy()
    mapset.append(fmap)
    labelset.append(labels)

smap = np.vstack(mapset[:-1])
slabel = np.hstack(labelset[:-1])

means = []
dmeans = []

for num in range(10):
    subset1 = smap[slabel == num]
    means1 = np.mean(subset1, axis=0)
    tmp = subset1 - means1
    dists1 = np.linalg.norm(tmp, axis=1)
    means.append(means1)
    dmeans.append(np.mean(dists1))
# print(means[0])

for i in range(10):
    for j in range(i+1):
        print(np.linalg.norm(means[i] - means[j]), end=',')
    print()

[1.23879255e-03 5.22287451e-02 9.59805213e-03 3.86472861e-03
 3.61749339e+00 1.08772459e+01 1.68376043e-02 3.06823780e-03
 7.96256065e-02 5.04932739e-03 5.27169928e-03 1.30827236e+01
 2.79178098e-03 1.87596083e-01 2.74690171e-03 8.21155262e+00
 6.30398607e+00 9.01248958e-03 1.49523973e+01 5.73888254e+00
 7.42509794e+00 7.69134844e-04 3.46922083e-03 1.02345459e-02
 8.17765713e+00 5.01655787e-03 6.57811761e-02 7.07086001e-04
 1.82856526e-02 9.46878397e-04 1.05989046e+01 1.10120687e+01
 4.83924113e-02 9.65701199e+00 8.65636952e-03 7.36839008e+00
 3.64728928e+00 1.82187872e-03 9.03719556e-05 7.35651629e-05
 4.03609499e-03 3.06630530e-03 9.86695576e+00 2.26608276e+00
 2.35741446e-03 2.42280774e-03 4.69894335e-02 3.59939504e-03
 1.16868806e-03 3.41249455e-04 1.18673115e+01 1.00807764e-03
 3.05456114e+00 8.20095253e+00 6.77808762e+00 1.06108002e-02
 7.24169798e-03 3.85132022e-02 4.85484409e+00 2.05113064e-03
 2.43703742e-02 2.06792774e-03 3.07185412e-03 0.00000000e+00
 4.32378333e-03 1.493577

In [12]:
tmapset = []
tlabelset = []
for data, labels in real_test_loader:
    data=data.to(device)
    fmap = encoder(data).cpu().detach().numpy()
    labels=labels.to(device).cpu().detach().numpy()
    tmapset.append(fmap)
    tlabelset.append(labels)

tmap = np.vstack(tmapset[:-1])
tlabel = np.hstack(tlabelset[:-1])

tmeans = []
tdmeans = []
for num in range(10):
    subset1 = tmap[tlabel == num]
    means1 = np.mean(subset1, axis=0)
    tmp = subset1 - means1
    dists1 = np.linalg.norm(tmp, axis=1)
    tmeans.append(means1)
    tdmeans.append(np.mean(dists1))
print(dmeans)

[12.239137, 10.070674, 14.430747, 12.062499, 9.656775, 12.180727, 9.247488, 11.649317, 10.3363905, 8.067777]


In [13]:
for num in range(10):
    subset1 = tmap[tlabel == num]
    tsd = []
    for i in range(10):
        tmp = np.linalg.norm((subset1 - means[i]), axis=1)
        tsd.append(np.mean(tmp))
    print(tsd)

[39.3769, 59.36251, 59.26669, 60.47117, 56.73069, 60.30749, 59.089848, 59.025127, 58.65511, 54.43972]
[57.060223, 30.760845, 49.41133, 51.64115, 48.102978, 52.98966, 51.64625, 50.552414, 49.859398, 47.463226]
[52.957363, 44.977097, 28.1175, 46.078617, 43.06595, 49.11844, 47.42685, 45.62019, 45.558594, 43.595142]
[57.84237, 51.6468, 49.69099, 36.426094, 49.460583, 52.86915, 53.604626, 51.795273, 51.22441, 48.475395]
[52.19363, 45.416985, 45.487587, 47.445435, 27.51278, 46.69033, 45.324505, 44.57756, 43.937416, 38.480984]
[54.28899, 48.70962, 49.12572, 48.290226, 44.866276, 34.65721, 47.97307, 48.686527, 48.04868, 44.503345]
[57.05126, 50.918297, 51.772343, 53.702095, 47.700718, 51.711727, 33.251354, 51.614666, 50.984833, 49.11289]
[55.980247, 49.579895, 49.270756, 51.59235, 46.580143, 52.01401, 51.11465, 32.4685, 49.11644, 45.28119]
[55.622772, 48.26707, 49.037125, 50.8245, 46.03742, 51.419674, 50.37614, 48.91806, 33.908077, 44.40021]
[51.362907, 46.026337, 46.77273, 47.708145, 39.61631