In [40]:
import time
import numpy as np
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from tqdm.notebook import tqdm

In [41]:
from FADAnet.S2Mloader import *
from FADAnet.FADAmodule import *

In [42]:
# Objective assembled by the training cell:
#   loss = loss1 + gamma*loss2/(10*n_support) + theta*loss3/(10*n_support)

# --- schedule / batching ---
n_epoch, n_epoch_pt = 161, 0              # n_epoch: total number of epochs
batch_size, batch_size_test = 256, 512    # SVHN train loader / MNIST test loader
lr = 0.0007
# NOTE(review): `lr` and `n_epoch_pt` are not referenced by any other visible
# cell (the optimizer is built without an explicit learning rate) -- confirm
# whether they are still needed.

# --- loss weighting ---
gamma, theta = 0.05, 0.04                 # weights of loss2 and loss3 in the objective
n_support = 7                             # passed as `n` to create_target_samples
loss3_margin = 0.7                        # margin of the CosineEmbeddingLoss behind loss3

# --- device ---
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# --- reproducibility: same seed for NumPy, torch (CPU) and torch (CUDA) ---
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(777)

In [43]:
# --- data loaders: SVHN (train / test) and the MNIST test split ---
train_dataloader = svhn_dataloader(train=True, batch_size=batch_size)
test_dataloader = svhn_dataloader(train=False)
real_test_loader = mnist_dataloader(train=False, batch_size=batch_size_test)

# --- networks ---
# Classifier is instantiated before Encoder, exactly as before: the global RNG
# is seeded in the config cell, so the order may affect the initial weights.
classifier = Classifier()
encoder = Encoder()
for _net in (classifier, encoder):
    _net.to(device)

# --- losses ---
# loss_fn1: classification; loss_fn2 / loss_fn3: cosine-embedding terms
# (loss_fn3 uses the configured margin).
loss_fn1, loss_fn2, loss_fn3 = (
    torch.nn.CrossEntropyLoss(),
    torch.nn.CosineEmbeddingLoss(),
    torch.nn.CosineEmbeddingLoss(margin=loss3_margin),
)

# --- optimizer: one Adadelta instance over encoder + classifier parameters ---
trainable_params = [*encoder.parameters(), *classifier.parameters()]
optimizer = torch.optim.Adadelta(trainable_params)

# --- target-domain samples consumed by the training loop ---
X_t, Y_t = create_target_samples(n=n_support)

Using downloaded and verified file: ./data/SVHN/train_32x32.mat
Using downloaded and verified file: ./data/SVHN/test_32x32.mat


In [44]:
# Training loop.
#
# Per batch:
#   loss1 - cross-entropy of the classifier on the source batch
#   loss2 - cosine-embedding term between (target feature - own source class mean)
#           and the differences of consecutive source class means
#   loss3 - cosine-embedding term (with margin) between (target feature - own
#           class mean) and (target feature - every other class mean)
#   loss  = loss1 + gamma*loss2/(10*n_support) + theta*loss3/(10*n_support)
#
# Reporting:
#   every  5 epochs - losses of the last batch
#   every 10 epochs - classifier accuracy on `test_dataloader`
#   every 20 epochs - nearest-class-mean accuracy on `real_test_loader`
#
# Evaluation runs under torch.no_grad() with both networks in eval mode, so no
# autograd graph is retained for the collected features; training mode is
# restored afterwards.

# Loop-invariant setup (previously redone for every batch / support sample).
X_t = X_t.to(device)
Y_t = Y_t.to(device)
target_labels = [int(Y_t[k]) for k in range(10 * n_support)]  # plain ints for list indexing
neg_target_10 = torch.full((10,), -1.0, device=device)        # "dissimilar" targets for the cosine losses
neg_target_9 = torch.full((9,), -1.0, device=device)

for epoch in range(n_epoch):
    encoder.train()
    classifier.train()

    for data, labels in train_dataloader:
        data = data.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        map_s = encoder(data)
        y_pred = classifier(map_s)
        loss1 = loss_fn1(y_pred, labels)
        map_t = encoder(X_t)

        # Tensor zeros so the `.item()` calls in the progress print always work.
        loss2 = map_s.new_zeros(())
        loss3 = map_s.new_zeros(())

        # The mean of an empty subset is NaN and would poison the weights after
        # backward(), so the alignment terms are only added when every class
        # is present in the batch (e.g. not for a short final batch lacking one).
        class_feats = [map_s[labels == num] for num in range(10)]
        if all(feats.shape[0] > 0 for feats in class_feats):
            means_s = [feats.mean(dim=0) for feats in class_feats]
            # Differences of consecutive class means; i = 0 wraps around to the
            # last class (means_s[-1] - means_s[0]). Same for every support sample.
            Cplane = torch.stack([means_s[i - 1] - means_s[i] for i in range(10)])

            for ctr in range(10 * n_support):
                num = target_labels[ctr]
                # Offset of this target feature from its own source class mean,
                # repeated once per row of the comparison matrices.
                dd = (map_t[ctr] - means_s[num]).unsqueeze(0).expand(10, -1)
                loss2 = loss2 + loss_fn2(-dd, Cplane, neg_target_10)
                loss2 = loss2 + loss_fn2(dd, Cplane, neg_target_10)
                to_other_means = torch.stack(
                    [map_t[ctr] - means_s[i] for i in range(10) if i != num]
                )
                loss3 = loss3 + loss_fn3(dd[:-1], to_other_means, neg_target_9)

        loss = loss1 + gamma*loss2/(10*n_support) + theta*loss3/(10*n_support)
        loss.backward()
        optimizer.step()

    if epoch%5 == 0: print("loss1:", loss1.item(), "   loss2:", loss2.item(), "   loss3:", loss3.item())

    # Both evaluations below only happen on multiples of 10.
    if epoch%10 != 0:
        continue

    encoder.eval()
    classifier.eval()
    with torch.no_grad():
        # --- source-domain accuracy (sample-weighted, so a short last batch
        #     does not count as much as a full one) ---
        n_correct = 0
        n_seen = 0
        for data, labels in test_dataloader:
            data = data.to(device)
            labels = labels.to(device)
            y_test_pred = classifier(encoder(data))
            n_correct += (torch.max(y_test_pred, 1)[1] == labels).sum().item()
            n_seen += labels.size(0)
        accuracy = round(n_correct / float(n_seen), 3)
        print("On source domain: Epoch %d/%d  accuracy: %.3f "%(epoch+1,n_epoch,accuracy))

        if epoch%20 == 0:
            # --- class means of the source training features ---
            mapset_f = []
            labelset_f = []
            for data, labels in train_dataloader:
                mapset_f.append(encoder(data.to(device)))
                labelset_f.append(labels.to(device))
            # The final batch is left out, as in the original computation.
            map_f = torch.cat(mapset_f[:-1])
            label_f = torch.cat(labelset_f[:-1])
            means_f = [torch.mean(map_f[label_f == num], dim=0) for num in range(10)]

            # --- nearest-class-mean accuracy on the target test loader ---
            deno = 0
            acc = 0
            for data, labels in real_test_loader:
                data = data.to(device)
                labels = labels.to(device)
                map_ff = encoder(data)
                # Row ii holds the Euclidean distance of every sample to class mean ii.
                distTS = torch.stack(
                    [torch.norm(map_ff - means_f[ii], dim=1) for ii in range(10)]
                )
                acc += torch.sum(torch.argmin(distTS, dim=0) == labels).item()
                deno += labels.size(0)  # real batch length, not the nominal batch size
                if deno > 5100: break   # evaluate on a subset of just over 5100 samples
            print("-------------------------------------------------")
            print("Another one on TD: Epoch %d/%d  accuracy:"%(epoch+1,n_epoch), acc / deno)
            print("-------------------------------------------------")

    # Leave both networks in training mode, as they were before evaluation.
    encoder.train()
    classifier.train()
        
        

loss1: 2.2746148109436035    loss2: 1.6422038078308105    loss3: 17.466873168945312
On source domain: Epoch 1/161  accuracy: 0.196 
-------------------------------------------------
Another one on TD: Epoch 1/161  accuracy: 0.1294921875
-------------------------------------------------
loss1: 1.6324539184570312    loss2: 14.774789810180664    loss3: 1.2258985042572021
loss1: 1.578744888305664    loss2: 11.0239839553833    loss3: 0.4115108549594879
On source domain: Epoch 11/161  accuracy: 0.774 
loss1: 1.5579211711883545    loss2: 8.534467697143555    loss3: 0.21657395362854004
loss1: 1.5798814296722412    loss2: 7.406755447387695    loss3: 0.619644045829773
On source domain: Epoch 21/161  accuracy: 0.803 
-------------------------------------------------
Another one on TD: Epoch 21/161  accuracy: 0.7271484375
-------------------------------------------------
loss1: 1.5007237195968628    loss2: 9.710536003112793    loss3: 0.1150103509426117
loss1: 1.4876405000686646    loss2: 7.9934058

In [46]:
# result check: per-class statistics of the encoder features on `train_dataloader`
#
# Features are extracted under torch.no_grad() with the encoder in eval mode
# (no autograd graph is built); the encoder's previous mode is restored
# afterwards so re-running the training cell is unaffected.
mapset = []
labelset = []
encoder_was_training = encoder.training
encoder.eval()
with torch.no_grad():
    for data, labels in train_dataloader:
        fmap = encoder(data.to(device)).cpu().numpy()
        mapset.append(fmap)
        # Labels are never needed on the device; convert straight to NumPy.
        labelset.append(labels.detach().cpu().numpy())
encoder.train(encoder_was_training)

# NOTE(review): the final batch is dropped here, as before -- vstack/hstack
# would accept a shorter last batch, so confirm whether this is intentional.
smap = np.vstack(mapset[:-1])
slabel = np.hstack(labelset[:-1])

means = []    # means[c]: mean feature vector of class c
dmeans = []   # dmeans[c]: mean Euclidean distance of class-c features to means[c]

for num in range(10):
    subset1 = smap[slabel == num]
    means1 = np.mean(subset1, axis=0)
    tmp = subset1 - means1
    dists1 = np.linalg.norm(tmp, axis=1)
    means.append(means1)
    dmeans.append(np.mean(dists1))
print(means[0])

# Lower-triangular table of pairwise distances between class means.
for i in range(10):
    for j in range(i+1):
        print(np.linalg.norm(means[i] - means[j]), end=',')
    print()
    print()

[-23.688744    21.273893     3.1445177    1.1988074    6.0620027
 -11.410247    27.245071    -7.2591133  -10.65264    -28.3089
 -22.033226     2.842914   -14.670648     6.21544     42.38876
  39.800205    24.564138     7.4917865   34.946407     8.950566
  28.74279     -0.63787794  18.698229   -35.204746   -16.776817
   9.953578   -12.181389    18.404148    14.666326   -19.785471
 -15.118991     6.8794265  -13.031214     0.9114728  -24.350803
  27.645535     2.0251348    3.2660334  -23.596907    14.472404
 -12.164628    60.08781     42.950176    24.292904    -7.956047
 -38.490963   -14.454111    33.338936    24.302477   -25.859348
  16.611355   -28.318966   -23.777386    27.504766    36.42862
  10.824631   -19.561558   -40.200413    49.791332     5.3936415
 -19.12181    -52.97179      0.12562715 -10.164382  ]
0.0,
236.48732,0.0,
287.08545,259.55295,0.0,
275.6476,243.73723,251.43686,0.0,
296.27084,231.9413,259.86368,260.98044,0.0,
323.48795,292.15512,302.3135,250.78674,289.3931,0.0,
264.

In [47]:
# Per-class statistics of the encoder features on `real_test_loader`.
#
# Features are extracted under torch.no_grad() with the encoder in eval mode
# (no autograd graph is built); the encoder's previous mode is restored
# afterwards so re-running the training cell is unaffected.
tmapset = []
tlabelset = []
encoder_was_training = encoder.training
encoder.eval()
with torch.no_grad():
    for data, labels in real_test_loader:
        fmap = encoder(data.to(device)).cpu().numpy()
        tmapset.append(fmap)
        # Labels are never needed on the device; convert straight to NumPy.
        tlabelset.append(labels.detach().cpu().numpy())
encoder.train(encoder_was_training)

# NOTE(review): the final batch is dropped here, as before -- confirm whether
# this is intentional.
tmap = np.vstack(tmapset[:-1])
tlabel = np.hstack(tlabelset[:-1])

tmeans = []    # tmeans[c]: mean feature vector of label c
tdmeans = []   # tdmeans[c]: mean Euclidean distance of label-c features to tmeans[c]
for num in range(10):
    subset1 = tmap[tlabel == num]
    means1 = np.mean(subset1, axis=0)
    tmp = subset1 - means1
    dists1 = np.linalg.norm(tmp, axis=1)
    tmeans.append(means1)
    tdmeans.append(np.mean(dists1))
# NOTE(review): this prints `dmeans` (computed in the previous cell), while the
# `tdmeans` built above is never displayed -- confirm which one was intended.
print(dmeans)

[131.37401, 106.545456, 127.75493, 113.06761, 117.81129, 142.30075, 150.75305, 78.9538, 125.301, 115.88948]


In [48]:
# Distance table: one printed row per label of `tmap`, one column per entry of
# `means`. Each value is the average Euclidean distance from the features with
# that label to the given class mean.
for num in range(10):
    subset1 = tmap[tlabel == num]
    tsd = [
        np.mean(np.linalg.norm(subset1 - means[k], axis=1))
        for k in range(10)
    ]
    print(tsd)

[290.9956, 373.67953, 389.38113, 393.7371, 405.95767, 426.6419, 387.4414, 376.52304, 387.43335, 363.84274]
[323.97153, 230.49045, 322.41327, 322.7934, 305.6804, 358.49924, 363.0555, 292.81622, 327.4344, 322.03897]
[436.72092, 413.77103, 300.74545, 398.9159, 403.47165, 435.0979, 451.06293, 394.66595, 422.61343, 418.6802]
[379.4224, 360.74832, 362.70453, 263.1432, 368.55246, 357.39172, 387.0257, 346.47733, 348.38406, 361.4477]
[409.5532, 364.7748, 380.37994, 382.64645, 272.24857, 399.53793, 396.94843, 364.98962, 386.66483, 357.1262]
[469.2989, 448.62766, 454.87463, 405.29953, 436.79926, 294.37714, 372.53497, 428.3165, 433.06668, 425.2241]
[376.67355, 375.45593, 380.82263, 370.85666, 325.3451, 353.11145, 302.06384, 359.21637, 359.46265, 380.14117]
[355.14615, 331.12982, 346.1563, 338.7127, 341.7284, 355.0337, 364.51785, 239.53, 344.7452, 327.47513]
[365.9263, 354.0274, 358.7895, 324.70456, 352.53604, 361.3994, 341.49405, 329.69452, 258.12878, 329.79816]
[326.7711, 326.8759, 333.70468, 330

In [49]:
# Show the mean feature vector of the label-0 samples from `real_test_loader`
# (first entry of `tmeans`), for comparison with `means[0]` printed earlier.
print(tmeans[0])

[-25.745762    15.9033985    4.660685    -3.5865479  -51.985085
 -75.84789     69.041664   -24.451033    36.76213     16.207006
 -49.550293    18.481579   -33.656685   -22.327173   -26.14766
   4.639961    -0.1664241  -16.648289    85.47253    -12.875888
  -5.276546   -12.34953     -8.334483   -45.763218   -20.043188
  11.530189    -3.2756162   13.591866    39.04248     -5.4091463
 -12.622531    -4.3483872  -13.063196    35.20257    -21.803411
  36.26048    -29.631714     2.0178175   -9.971764   -12.942284
  14.583926    96.17354     62.914017    81.518      -34.537804
 -33.287106   -63.502804    61.22401    -36.03754    -58.169857
   9.957944   -36.862183   -26.718363    28.1729      31.984795
   6.60877     31.57246    -97.49038     72.86181     40.584488
  -0.12022278 -33.237915    40.598946     5.8581066 ]
