In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import numpy as np
import pickle , gzip
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier


seed = 999
# Reproducibility: seed the RNGs used in this notebook and force cuDNN onto
# deterministic kernels. `benchmark` must also be disabled, otherwise cuDNN may
# autotune to a different (possibly non-deterministic) algorithm on each run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(seed)  # seeds the generators of all devices, CUDA included
np.random.seed(seed)

In [1]:
def load_rotated_mnist(datapath='MNIST_6rot.pkl.gz', left_out_idx=0):
    """Load the rotated-MNIST domains and split off one domain for testing.

    The file is read as a gzip-compressed pickle holding a sequence of
    ``(x, y)`` pairs, one per domain.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code; only load
    files from a trusted source.

    Args:
        datapath: path to the gzip-compressed pickle file.
        left_out_idx: index of the domain to hold out (any valid sequence
            index, negative indices included).

    Returns:
        ``(src_domains, (x_test, y_test))`` where ``src_domains`` is a new list
        of every domain except the held-out one (original order kept) and
        ``(x_test, y_test)`` is the held-out domain's pair.

    Raises:
        IndexError: if ``left_out_idx`` is out of range.
    """
    # `with` guarantees the gzip handle is closed (it was previously leaked).
    with gzip.open(datapath, 'rb') as fh:
        domains = pickle.load(fh, encoding='iso-8859-1')

    # Shallow copy so that deleting the held-out entry leaves `domains` intact.
    src_domains = list(domains)
    del src_domains[left_out_idx]

    x_test, y_test = domains[left_out_idx]
    return src_domains, (x_test, y_test)

In [None]:
class Autoencoder(nn.Module):
    """Single-hidden-layer autoencoder.

    encoder: Linear(input_dim -> hidden_dim) + ReLU
    decoder: Linear(hidden_dim -> input_dim) + ReLU

    Args:
        input_dim: length of each flattened input vector (default 16*16).
        hidden_dim: size of the hidden code (default 1500).
        device: where the parameters are placed. ``None`` picks CUDA when it
            is available and CPU otherwise. Previously ``.cuda()`` was
            hard-coded, which made the model unusable without a GPU.
    """

    def __init__(self, input_dim=16 * 16, hidden_dim=1500, device=None):
        super(Autoencoder, self).__init__()
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.ReLU(True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim, bias=True),
            nn.ReLU(True),
        )
        # Move once, after construction. Outputs of forward() then live on the
        # same device as the inputs/parameters, so no per-call .cuda() is needed.
        self.to(device)

    def forward(self, x):
        """Reconstruct ``x``: decoder(encoder(x))."""
        return self.decoder(self.encoder(x))

    def features(self, x):
        """Return the hidden code ``encoder(x)``."""
        return self.encoder(x)

In [7]:
class MTAE:
    """Multi-task autoencoder: one encoder shared by all autoencoders, one
    decoder per autoencoder.

    Each autoencoder ``i`` has its own ASGD optimizer. The optimizers are built
    AFTER the encoder has been shared, so each of them includes the shared
    encoder's parameters. In the previous version, optimizers 1..n-1 held stale
    copies of never-used encoder parameters. As a result, the "shared" encoder
    was only ever updated by steps on autoencoder 0.

    Args:
        n_domains: number of autoencoders / decoders (default 5).
        learning_rate: ASGD learning rate (default 0.03).
        weight_decay: ASGD weight decay (default 3e-4).
        device: where the autoencoders are placed. ``None`` picks CUDA when
            available, otherwise CPU.
    """

    def __init__(self, n_domains=5, learning_rate=0.03, weight_decay=3e-4, device=None):
        super(MTAE, self).__init__()
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.autoEncoders = [Autoencoder().to(device) for _ in range(n_domains)]

        # Share autoencoder 0's encoder module. Assigning a Module to an
        # nn.Module attribute re-registers it, so each autoencoder's
        # .parameters() now yields the shared encoder plus its own decoder.
        shared_encoder = self.autoEncoders[0].encoder
        for ae in self.autoEncoders[1:]:
            ae.encoder = shared_encoder

        # Must be created after sharing (see class docstring).
        self.optimizers = [
            torch.optim.ASGD(params=ae.parameters(), lr=learning_rate, weight_decay=weight_decay)
            for ae in self.autoEncoders
        ]

        self.criterion = nn.MSELoss()
        # Kept for backward compatibility: aliases of the shared encoder's first
        # layer parameters. The optimizers update these in place, so they stay current.
        self.weight = shared_encoder[0].weight
        self.bias = shared_encoder[0].bias

    def train(self, X, Y, domainId):
        """Run one optimisation step of autoencoder ``domainId``.

        The step minimises MSE(autoencoder(X), Y) and updates the shared
        encoder plus that autoencoder's decoder.

        Returns:
            The detached loss tensor. Previously nothing was returned, so
            existing callers are unaffected.
        """
        model = self.autoEncoders[domainId]
        optimizer = self.optimizers[domainId]

        model.train()
        # Also clears the shared encoder's grads, because its params belong to this optimizer.
        optimizer.zero_grad()
        cost = self.criterion(model(X), Y)
        cost.backward()
        optimizer.step()
        return cost.detach()

    def get_features(self, X):
        """Encode ``X`` with the shared encoder; no autograd graph is recorded."""
        model = self.autoEncoders[0]
        model.eval()
        with torch.no_grad():
            return model.features(X)

In [4]:
def linear_svm(training_data, test_data, max_iter=10000):
    """Fit a one-vs-rest linear SVM and return its accuracy on held-out data.

    Both arrays use the same layout: column 0 holds the label and the
    remaining columns hold the features.

    Args:
        training_data: 2-D array used to fit the classifier.
        test_data: 2-D array used to score it.
        max_iter: iteration cap passed to ``LinearSVC`` (default 10000).

    Returns:
        Fraction (float in [0, 1]) of test rows whose predicted label equals
        the label in column 0.

    Raises:
        ValueError: if ``test_data`` has no rows.
    """
    training_data = np.asarray(training_data)
    test_data = np.asarray(test_data)
    if len(test_data) == 0:
        raise ValueError("test_data must contain at least one row")

    x_train, y_train = training_data[:, 1:], training_data[:, 0]
    x_test, y_test = test_data[:, 1:], test_data[:, 0]

    # Named `clf` so it does not shadow this function's own name.
    clf = OneVsRestClassifier(LinearSVC(max_iter=max_iter), n_jobs=-1)
    clf.fit(x_train, y_train)

    # Only test accuracy is returned. The former training-set prediction was
    # expensive and its results were never used, so it is gone.
    predictions = clf.predict(x_test)
    return float(np.mean(predictions == y_test))

In [5]:
def _train_mtae(mtae, domains, n_samples, n_epochs, batch_size, device):
    """Run the cross-domain reconstruction loop on ``mtae``.

    Returns the last epoch's permuted inputs and labels. The inputs are a
    float32 tensor of shape (n_domains, n_samples, ...) and the labels an
    array stacked per domain, so features and labels stay row-aligned.
    """
    n_domains = len(domains)
    x_train = y_train = None
    for _epoch in range(n_epochs):
        # One permutation for every domain keeps row r matched across domains.
        perm = np.random.permutation(n_samples)
        x_train = torch.as_tensor(
            np.stack([xs[perm] for xs, _ys in domains]), dtype=torch.float32
        ).to(device)
        y_train = np.stack([ys[perm] for _xs, ys in domains])

        # Autoencoder `src` learns to map batches of domain `src` onto the
        # matching rows of every domain `tgt`.
        for src in range(n_domains):
            for tgt in range(n_domains):
                for start in range(0, n_samples, batch_size):
                    stop = start + batch_size
                    mtae.train(x_train[src][start:stop, :], x_train[tgt][start:stop, :], src)
    return x_train, y_train


def _features_with_labels(mtae, x, labels):
    """Encode ``x`` with the MTAE encoder, flatten to 2-D, and prepend the labels as column 0."""
    feat = mtae.get_features(x).detach().cpu().numpy()
    feat = feat.reshape(-1, feat.shape[-1])
    return np.hstack((np.asarray(labels).reshape(-1, 1), feat))


def domain_specific_training(dom, n_epochs=150, batch_size=50, n_repeats=1):
    """Leave-one-domain-out evaluation of MTAE features with a linear SVM.

    An MTAE is trained on the source domains returned by
    ``load_rotated_mnist(left_out_idx=dom)``. A linear SVM is then fitted on
    the encoded source data and scored on the encoded held-out domain.

    Args:
        dom: index of the held-out (test) domain.
        n_epochs: training epochs per repeat (default 150). Must be >= 1.
        batch_size: rows per optimisation step (default 50).
        n_repeats: independent MTAE trainings to average over (default 1).

    Returns:
        Mean test accuracy over the repeats. It is also printed, as before.

    Raises:
        ValueError: if ``n_epochs`` or ``n_repeats`` is < 1, or if the source
            domains do not all have the same number of samples.
    """
    if n_epochs < 1 or n_repeats < 1:
        raise ValueError("n_epochs and n_repeats must both be >= 1")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    domains, (x_test, y_test) = load_rotated_mnist(left_out_idx=dom)
    x_test = torch.as_tensor(np.asarray(x_test), dtype=torch.float32).to(device)
    y_test = np.asarray(y_test).reshape(-1, 1)

    # Sizes are derived from the data instead of the hard-coded 1000 / 5000.
    # A shared permutation is only meaningful if the domains are row-aligned.
    n_samples = len(domains[0][0])
    if any(len(xs) != n_samples or len(ys) != n_samples for xs, ys in domains):
        raise ValueError("all source domains must have the same number of samples")

    total_accuracy = 0.0
    for _repeat in range(n_repeats):
        mtae = MTAE()
        x_train, y_train = _train_mtae(mtae, domains, n_samples, n_epochs, batch_size, device)
        train = _features_with_labels(mtae, x_train, y_train)
        test = _features_with_labels(mtae, x_test, y_test)
        total_accuracy += linear_svm(train, test)

    mean_accuracy = total_accuracy / n_repeats
    print("Accuracy :", mean_accuracy)
    return mean_accuracy

In [None]:
# Leave-one-domain-out protocol: each of the six domains is held out once as
# the test domain while the remaining ones are used for training.
for dom in range(6):
    print(f"----------------------Domain.{dom}---------------------")
    domain_specific_training(dom)