<a href="https://colab.research.google.com/github/Kolessov/Deep-weight-prior/blob/main/FADA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import time
import numpy as np
import argparse

import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [5]:
class BasicModule(torch.nn.Module):
    def __init__(self):
        super(BasicModule,self).__init__()

    def load(self,path):
        self.load_state_dict(torch.load(path))

    def save(self, path=None):
        if path is None:
            name='result/best_model.pth'
            torch.save(self.state_dict(),name)
            return name
        else:
            torch.save(self.state_dict(),path)
            return path

In [6]:
class DCD(BasicModule):
    def __init__(self,h_features=64,input_features=2*84):
        super(DCD,self).__init__()

        self.fc1=nn.Linear(input_features,h_features)
        self.fc2=nn.Linear(h_features,h_features)
        self.fc3=nn.Linear(h_features,4)

    def forward(self,inputs):
        out=F.relu(self.fc1(inputs))
        out=self.fc2(out)
        return F.softmax(self.fc3(out),dim=1)


In [7]:
class Classifier(BasicModule):
    def __init__(self,input_features=84):
        super(Classifier,self).__init__()
        self.fc=nn.Linear(input_features,10)

    def forward(self,input):
        return F.softmax(self.fc(input),dim=1)

In [8]:
class Encoder(BasicModule):
    def __init__(self):
        super(Encoder,self).__init__()

        self.conv1=nn.Conv2d(1,6,5)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(256,120)
        self.fc2=nn.Linear(120,84)
        #self.fc3=nn.Linear(84,64)

    def forward(self,input):
        out=F.relu(self.conv1(input))
        out=F.max_pool2d(out,2)
        out=F.relu(self.conv2(out))
        out=F.max_pool2d(out,2)
        out=out.view(out.size(0),-1)

        out=F.relu(self.fc1(out))
        #out=F.relu(self.fc2(out))
        out=self.fc2(out)

        return out


In [9]:
def usps_dataloader(batch_size=256,train=True):

    dataloader=DataLoader(
    datasets.USPS('./data',train=train,download=True,
                   transform=transforms.Compose([ 
                       transforms.ToTensor(),
                       transforms.Resize((28,28)),
                       transforms.Normalize((0.5,),(0.5,))
                   ])),
    batch_size=batch_size,shuffle=True)

    return dataloader

In [10]:
def svhn_dataloader(batch_size=4,train=True):
    dataloader = DataLoader(
        datasets.SVHN('./data', split=('train' if train else 'test'), download=True,
                       transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.Grayscale(),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,))
                       ])),
        batch_size=batch_size, shuffle=False)

    return dataloader

In [None]:
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

dt = datasets.MNIST('./',train=True,download=True)

--2021-03-29 16:56:28--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2021-03-29 16:56:29--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘MNIST.tar.gz’

MNIST.tar.gz            [       <=>          ]  33.20M  5.03MB/s    in 19s     

2021-03-29 16:56:49 (1.77 MB/s) - ‘MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-idx3-ubyte
MNIST/raw/tra

In [11]:
def sample_data():
    dataset=datasets.USPS('./data',train=True,download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Resize((28,28)),
                       transforms.Normalize((0.5, ),(0.5, ))
                   ]))
    n=len(dataset)

    X=torch.Tensor(n,1,28,28)
    Y=torch.LongTensor(n)

    inds=torch.randperm(len(dataset))
    for i,index in enumerate(inds):
        x,y=dataset[index]
        X[i]=x
        Y[i]=y
    return X,Y



In [12]:
def create_target_samples( mask_matrix):

    dataset=datasets.SVHN('./data', split='train', download=True,
                       transform=transforms.Compose([
                           transforms.Resize((28,28)),
                           transforms.Grayscale(),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, ), (0.5,))
                       ]))
    
    """
    X,Y=[],[]
    classes=10*[n]

    i=0
    while True:
        if len(X)==n*10:
            break
        x,y=dataset[i]
        if classes[y]>0:
            X.append(x)
            Y.append(y)
            classes[y]-=1
        i+=1

    assert (len(X)==n*10)
    """

    X,Y =[],[]

    for label, num_elements in enumerate(mask_matrix):

      idxs = torch.nonzero( torch.tensor( [dataset[i][1] == label for i in range(len(dataset))]  ) )
      idxs = idxs[torch.randperm(len(idxs))]

      counter = 0
      while counter < num_elements:
        X.append(dataset[idxs[counter]][0])
        Y.append(dataset[idxs[counter]][1])
        counter += 1
      


    return torch.stack(X,dim=0),torch.from_numpy(np.array(Y))

In [13]:
#opt=vars(parser.parse_args())

use_cuda=True if torch.cuda.is_available() else False
device=torch.device('cuda:0') if use_cuda else torch.device('cpu')
torch.manual_seed(1)
if use_cuda:
    torch.cuda.manual_seed(1)

In [14]:
X_s,Y_s= sample_data()
#X_t,Y_t = create_target_samples()

In [15]:
X_t,Y_t = create_target_samples([2,2,3,6,5,5,4,7,2,3])

Using downloaded and verified file: ./data/train_32x32.mat


In [16]:
def create_groups(X_s,Y_s,X_t,Y_t,mask_matrix,seed=1):
    #change seed so every time wo get group data will different in source domain,but in target domain, data not change

    torch.manual_seed(1 + seed)
    torch.cuda.manual_seed(1 + seed)


    n=X_t.shape[0] # 38


    #shuffle order
    classes = torch.unique(Y_t) # 0 1 .. 9
    #classes=classes[torch.randperm(len(classes))] # 7 8 ..2


    class_num=classes.shape[0] 
    #shot=n//class_num # 7



    def s_idxs(c):
        idx=torch.nonzero(Y_s.eq(int(c)))

        return idx[torch.randperm(len(idx))][:7*2].squeeze() 

    def t_idxs(c):
        return torch.nonzero(Y_t.eq(int(c)))[:7].squeeze()

    source_idxs = list(map(s_idxs, classes))
    target_idxs = list(map(t_idxs, classes))

    source_matrix = source_idxs
    target_matrix = target_idxs
 


    G1, G2, G3, G4 = [], [] , [] , []
    Y1, Y2 , Y3 , Y4 = [], [] ,[] ,[]

   
    for i in range(10):
        for j in range(mask_matrix[i]):
            G1.append(( X_s[source_matrix[i][j*2]], X_s[source_matrix[i][j*2+1]]))
            Y1.append((Y_s[source_matrix[i][j*2]],Y_s[source_matrix[i][j*2+1]]))

            
            G2.append((X_s[source_matrix[i][j]],X_t[target_matrix[i][j]]))
            Y2.append((Y_s[source_matrix[i][j]],Y_t[target_matrix[i][j]]))

            G3.append((X_s[source_matrix[i%10][j]],X_s[source_matrix[(i+1)%10][j]]))
            Y3.append((Y_s[source_matrix[i % 10][j]], Y_s[source_matrix[(i + 1) % 10][j]]))
             
           

            G4.append((X_s[source_matrix[i%10][j]],X_t[target_matrix[(i+1)%10][0]]))
            Y4.append((Y_s[source_matrix[i % 10][j]], Y_t[target_matrix[(i + 1) % 10][0]]))

 
    groups=[G1,G2,G3,G4]
    groups_y=[Y1,Y2,Y3,Y4]

    #make sure we sampled enough samples
    #for g in groups:
        #assert(len(g)==n)
    return groups,groups_y

In [17]:
def sample_groups(X_s,Y_s,X_t,Y_t,mask_matrix,seed=1):


    print("Sampling groups")
    return create_groups(X_s,Y_s,X_t,Y_t,mask_matrix, seed=seed)

In [18]:
batch_size=32
n_epochs_1 = 10
#n_target_samples = 7
n_epochs_2 = 100
n_epochs_3 = 100

In [19]:
train_dataloader= usps_dataloader(batch_size=batch_size,train=True)
test_dataloader= usps_dataloader(batch_size=batch_size,train=False)

classifier= Classifier()
encoder= Encoder()
discriminator= DCD(input_features=2*84)

classifier.to(device)
encoder.to(device)
discriminator.to(device)
loss_fn=torch.nn.CrossEntropyLoss()

optimizer=torch.optim.Adam(list(encoder.parameters())+list(classifier.parameters()))


for epoch in range(n_epochs_1):

    for data,labels in train_dataloader:

        data=data.to(device)
        labels=labels.to(device)

        optimizer.zero_grad()

        y_pred=classifier(encoder(data))

        loss=loss_fn(y_pred,labels)
        loss.backward()

        optimizer.step()

    acc=0
    for data,labels in test_dataloader:
        data=data.to(device)
        labels=labels.to(device)
        y_test_pred=classifier(encoder(data))
        acc+=(torch.max(y_test_pred,1)[1]==labels).float().mean().item()

    accuracy=round(acc / float(len(test_dataloader)), 3)

    print("step1----Epoch %d/%d  accuracy: %.3f "%(epoch+1,n_epochs_1,accuracy))

step1----Epoch 1/10  accuracy: 0.854 
step1----Epoch 2/10  accuracy: 0.900 
step1----Epoch 3/10  accuracy: 0.912 
step1----Epoch 4/10  accuracy: 0.913 
step1----Epoch 5/10  accuracy: 0.923 
step1----Epoch 6/10  accuracy: 0.929 
step1----Epoch 7/10  accuracy: 0.936 
step1----Epoch 8/10  accuracy: 0.928 
step1----Epoch 9/10  accuracy: 0.933 
step1----Epoch 10/10  accuracy: 0.937 


In [20]:
groups,aa =  sample_groups(X_s,Y_s,X_t,Y_t,[2,2,3,6,5,5,4,7,2,3])

Sampling groups


In [21]:
mask_matrix = [2,2,3,6,5,5,4,7,2,3]

In [22]:
optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=0.001)


for epoch in range(n_epochs_2):
    # data
    groups,aa =  sample_groups(X_s,Y_s,X_t,Y_t,mask_matrix,seed=epoch)
    
    

    n_iters = 4 * len(groups[1])
    index_list = torch.randperm(n_iters)
    mini_batch_size=20 #use mini_batch train can be more stable


    loss_mean=[]

    X1=[]
    X2=[]
    ground_truths=[]

    for index in range(n_iters):

        ground_truth = index_list[index]//len(groups[1]) # from which subgroup

        x1,x2 = groups[ground_truth][index_list[index]-len(groups[1])*ground_truth]
        X1.append(x1)
        X2.append(x2)
        ground_truths.append(ground_truth)

        #select data for a mini-batch to train
        if (index+1)%mini_batch_size==0:

            X1=torch.stack(X1)
            X2=torch.stack(X2)
            ground_truths=torch.LongTensor(ground_truths)
            X1=X1.to(device)
            X2=X2.to(device)
            ground_truths=ground_truths.to(device)

            optimizer_D.zero_grad()
            X_cat=torch.cat([encoder(X1),encoder(X2)],1)


            y_pred=discriminator(X_cat.detach())


            loss=loss_fn(y_pred,ground_truths)
            loss.backward()
            optimizer_D.step()
            loss_mean.append(loss.item())
            X1 = []
            X2 = []
            ground_truths = []

    print("step2----Epoch %d/%d loss:%.3f"%(epoch+1,n_epochs_2,np.mean(loss_mean)))


Sampling groups
step2----Epoch 1/100 loss:1.387
Sampling groups
step2----Epoch 2/100 loss:1.318
Sampling groups
step2----Epoch 3/100 loss:1.262
Sampling groups
step2----Epoch 4/100 loss:1.208
Sampling groups
step2----Epoch 5/100 loss:1.165
Sampling groups
step2----Epoch 6/100 loss:1.132
Sampling groups
step2----Epoch 7/100 loss:1.084
Sampling groups
step2----Epoch 8/100 loss:1.045
Sampling groups
step2----Epoch 9/100 loss:1.030
Sampling groups
step2----Epoch 10/100 loss:1.004
Sampling groups
step2----Epoch 11/100 loss:0.966
Sampling groups
step2----Epoch 12/100 loss:0.952
Sampling groups
step2----Epoch 13/100 loss:0.939
Sampling groups
step2----Epoch 14/100 loss:0.921
Sampling groups
step2----Epoch 15/100 loss:0.904
Sampling groups
step2----Epoch 16/100 loss:0.927
Sampling groups
step2----Epoch 17/100 loss:0.900
Sampling groups
step2----Epoch 18/100 loss:0.887
Sampling groups
step2----Epoch 19/100 loss:0.866
Sampling groups
step2----Epoch 20/100 loss:0.862
Sampling groups
step2----Epoc

In [None]:
optimizer_g_h=torch.optim.Adam(list(encoder.parameters())+list(classifier.parameters()),lr=0.001)
optimizer_d=torch.optim.Adam(discriminator.parameters(),lr=0.001)


test_dataloader= svhn_dataloader(train=False,batch_size = batch_size)

acc_list = []
acc_class_list = []

for epoch in range(n_epochs_3):
    #---training g and h , DCD is frozen

    groups, groups_y = sample_groups(X_s,Y_s,X_t,Y_t,mask_matrix,seed= n_epochs_2 + epoch)
    G1, G2, G3, G4 = groups
    Y1, Y2, Y3, Y4 = groups_y

    groups_2 = [G2, G4]
    groups_y_2 = [Y2, Y4]

    n_iters = 2 * len(G2)
    index_list = torch.randperm(n_iters)

    n_iters_dcd = 4 * len(G2)
    index_list_dcd = torch.randperm(n_iters_dcd)

    mini_batch_size_g_h = 20 #data only contains G2 and G4 ,so decrease mini_batch

    mini_batch_size_dcd= 40 #data contains G1,G2,G3,G4 so use 40 as mini_batch




    X1 = []
    X2 = []
    ground_truths_y1 = []
    ground_truths_y2 = []
    dcd_labels=[]


    for index in range(n_iters):


        ground_truth=index_list[index]//len(G2)
        x1, x2 = groups_2[ground_truth][index_list[index] - len(G2) * ground_truth]
        y1, y2 = groups_y_2[ground_truth][index_list[index] - len(G2) * ground_truth]
        # y1=torch.LongTensor([y1.item()])
        # y2=torch.LongTensor([y2.item()])
        dcd_label=0 if ground_truth==0 else 2
        X1.append(x1)
        X2.append(x2)
        ground_truths_y1.append(y1)
        ground_truths_y2.append(y2)
        dcd_labels.append(dcd_label)

        if (index+1)%mini_batch_size_g_h==0:

            X1=torch.stack(X1)
            X2=torch.stack(X2)
            ground_truths_y1=torch.LongTensor(ground_truths_y1)
            ground_truths_y2 = torch.LongTensor(ground_truths_y2)
            dcd_labels=torch.LongTensor(dcd_labels)
            X1=X1.to(device)
            X2=X2.to(device)
            ground_truths_y1=ground_truths_y1.to(device)
            ground_truths_y2 = ground_truths_y2.to(device)
            dcd_labels=dcd_labels.to(device)

            optimizer_g_h.zero_grad()

            encoder_X1=encoder(X1)
            encoder_X2=encoder(X2)

            X_cat=torch.cat([encoder_X1,encoder_X2],1)
            y_pred_X1=classifier(encoder_X1)
            y_pred_X2=classifier(encoder_X2)
            y_pred_dcd=discriminator(X_cat)

            loss_X1=loss_fn(y_pred_X1,ground_truths_y1)
            loss_X2=loss_fn(y_pred_X2,ground_truths_y2)
            loss_dcd=loss_fn(y_pred_dcd,dcd_labels)

            loss_sum = loss_X1 + loss_X2 + 0.2 * loss_dcd

            loss_sum.backward()
            optimizer_g_h.step()

            X1 = []
            X2 = []
            ground_truths_y1 = []
            ground_truths_y2 = []
            dcd_labels = []


    #----training dcd ,g and h frozen
    X1 = []
    X2 = []
    ground_truths = []
    for index in range(n_iters_dcd):

        ground_truth=index_list_dcd[index]//len(groups[1])

        x1, x2 = groups[ground_truth][index_list_dcd[index] - len(groups[1]) * ground_truth]
        X1.append(x1)
        X2.append(x2)
        ground_truths.append(ground_truth)

        if (index + 1) % mini_batch_size_dcd == 0:
            X1 = torch.stack(X1)
            X2 = torch.stack(X2)
            ground_truths = torch.LongTensor(ground_truths)
            X1 = X1.to(device)
            X2 = X2.to(device)
            ground_truths = ground_truths.to(device)

            optimizer_d.zero_grad()
            X_cat = torch.cat([encoder(X1), encoder(X2)], 1)
            y_pred = discriminator(X_cat.detach())
            loss = loss_fn(y_pred, ground_truths)
            loss.backward()
            optimizer_d.step()
            # loss_mean.append(loss.item())
            X1 = []
            X2 = []
            ground_truths = []

    #testing
    acc_cl = []
    acc = 0
    for data, labels in test_dataloader:
        data = data.to(device)
        labels = labels.to(device)
        y_test_pred = classifier(encoder(data))
        acc += (torch.max(y_test_pred, 1)[1] == labels).float().mean().item()
        
        cur_data =   torch.tensor([0]*10).cuda()
        test_pred =  torch.tensor([0]*10).cuda()
        acc_class =  torch.tensor([0]*10).cuda()

        
        print("hui")
        for num_label in range(10):
          
           
          cur_data[num_label] = torch.cat([data_.unsqueeze(1) for data_, labels_ in zip(data, labels) if labels_ == num_label])
          #label_class[num_label] = torch.cat([labels_ for data_, labels_ in zip(data, labels) if labels_ == num_label])
          test_pred[num_label] = classifier(encoder(cur_data[num_label]))

          acc_class[num_label] +=  (torch.max(test_pred[num_label], 1)[1] == 0).float().mean().item()
       

    accuracy = round(acc / float(len(test_dataloader)), 3)
    
    acc_class_total = [round(element/ float(len(test_loader)),3) for element in acc_class]
    acc_list.append(accuracy)
    acc_cl.extend(acc_class_total)
    

    print("step3----Epoch %d/%d  accuracy: %.3f " % (epoch + 1,  n_epochs_3, accuracy))

Using downloaded and verified file: ./data/test_32x32.mat
