In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from zipfile import ZipFile
import skimage.io
from PIL import Image
import seaborn as sns
from sklearn.manifold import TSNE
import random
import umap
from itertools import product

from mnist_generator import get_mnist_loaders
from mnistm_generator import get_mnistm_loaders
from DANN import *
from DA import *
from test import *
from train import *
from visualize import *
from util import *

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Unpack the (train, eval, test) loaders returned by the project helper for MNIST.
(
    mnist_train_loader,
    mnist_eval_loader,
    mnist_test_loader,
) = get_mnist_loaders(batch_size=128)

In [4]:
# Unpack the (train, eval, test) loaders returned by the project helper for MNIST-M.
(
    mnistm_train_loader,
    mnistm_eval_loader,
    mnistm_test_loader,
) = get_mnistm_loaders(batch_size=128)

In [None]:
class S3GAN(nn.Module):
    """Convolutional autoencoder that swaps latent channel groups between two batches.

    A shared encoder maps a 3x28x28 image to a 64x4x4 latent and a shared decoder
    maps it back to 3x28x28. ``forward`` reconstructs both inputs and also decodes
    "mixed" latents built from the first 32 channels of A's code ("style") and the
    last 32 channels of B's code ("content").
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # 3x28x28
            nn.Conv2d(3, 8, 3),
            nn.InstanceNorm2d(8),
            nn.ReLU(),
            # 8x26x26
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.InstanceNorm2d(16),
            nn.ReLU(),
            # 16x13x13
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            # 32x7x7
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(),
            # 64x4x4
        )
        self.decoder = nn.Sequential(
            # 64x4x4
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(),
            # 32x7x7
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(16),
            nn.ReLU(),
            # 16x14x14 -> stride 2 with output_padding=1 is needed to reach 28x28
            nn.ConvTranspose2d(16, 3, 3, stride=2, padding=1, output_padding=1),
            # NOTE(review): InstanceNorm + ReLU on the output makes reconstructions
            # per-image standardized and non-negative; confirm this matches how the
            # data loaders scale pixel values.
            nn.InstanceNorm2d(3),
            nn.ReLU(),
            # 3x28x28
        )

    @staticmethod
    def _mix(latent_style, latent_content, split=32):
        """Concatenate channels [:split] of ``latent_style`` with channels [split:] of ``latent_content``."""
        return torch.cat([latent_style[:, :split], latent_content[:, split:]], dim=1)

    def forward(self, A, B):
        """Reconstruct A and B and decode style/content mixtures of their latents.

        Args:
            A: image batch supplying the "style" channels; the layer shape comments
               assume 3x28x28 images (TODO confirm against the loaders).
            B: image batch supplying the "content" channels, same shape as A.

        Returns:
            Tuple ``(rec_A, rec_B, mixed_image, mixed_rec_image)``:
            the two reconstructions, the image decoded from the mixed input
            latents, and the image decoded from the mixed latents of the
            re-encoded reconstructions.
        """
        latent_A = self.encoder(A)
        latent_B = self.encoder(B)
        rec_A = self.decoder(latent_A)
        rec_B = self.decoder(latent_B)

        # Tensor.detach() is not in-place: the result must be reassigned, otherwise
        # the mixing path would still backpropagate into the encoder.
        latent_A = latent_A.detach()
        latent_B = latent_B.detach()
        mixed_image = self.decoder(self._mix(latent_A, latent_B))

        latent_rec_A = self.encoder(rec_A).detach()
        latent_rec_B = self.encoder(rec_B).detach()
        mixed_rec_image = self.decoder(self._mix(latent_rec_A, latent_rec_B))

        return rec_A, rec_B, mixed_image, mixed_rec_image

In [11]:
# Cross-entropy over raw class logits (nn.CrossEntropyLoss applies log-softmax itself);
# `device` is the torch.device chosen in the device-selection cell.
discriminator_loss = nn.CrossEntropyLoss()
discriminator_loss = discriminator_loss.to(device)

In [8]:
def Gramian_matrix(x):
    """Gram matrix of channel activations, computed per sample.

    Args:
        x: 4-D tensor of shape (batch, channels, height, width).

    Returns:
        A (batch, channels, channels) tensor holding, for each sample, the inner
        products between its flattened channel maps. For a single-sample batch a
        2-D (channels, channels) tensor is returned, matching the earlier
        batch-size-1-only behaviour.
    """
    b, d, h, w = x.size()
    # reshape (not view) so non-contiguous feature maps are accepted as well;
    # keeping the batch axis avoids the size-mismatch error view(d, h*w) raised
    # for any batch larger than 1.
    features = x.reshape(b, d, h * w)
    gram = torch.bmm(features, features.transpose(1, 2))
    return gram[0] if b == 1 else gram

In [None]:
class Conceptual_style_loss(nn.Module):
    def __init__(self):
        super(Conceptual_style_loss, self).__init__()
        
    def forward(self, x1, x2):
        