In [1]:
import numpy as np
import math
import torch
from PIL import Image
from torchvision import datasets, transforms
from torch.nn.utils import weight_norm
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torch.nn.functional as F
import copy
import numpy as np
import torchvision.utils as vutils
import time
import matplotlib.pyplot as plt   
import torchvision.utils as vutils
# Dataset-wide constants: number of CIFAR-100 classes and size of its training split.
classes, train_total = 100, 50000


In [2]:
# Mount Google Drive (Colab only); the training loop saves its best checkpoints under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:

#modified from: 
#https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/cifar.py 
def x_u_split(labels, labeled_per_class, num_classes=None, seed=0):
    """Split training indices into a labelled subset and the remaining unlabelled ones.

    For every class, ``labeled_per_class`` indices are drawn without replacement.
    The selected indices are shuffled, and every index that was not selected is
    returned as unlabelled.

    Args:
        labels: sequence of integer class ids, one per training sample.
        labeled_per_class: number of labelled samples kept per class.
        num_classes: classes ``0 .. num_classes - 1`` are sampled; defaults to
            ``max(labels) + 1`` (100 for the CIFAR-100 targets used here).
        seed: seed of the private RNG. The default ``0`` reproduces the same
            split this function has always produced.

    Returns:
        ``(labeled_idx, unlabeled_idx)``. The first array holds the shuffled
        labelled indices. The second holds the sorted indices in
        ``range(len(labels))`` that were not selected.

    Raises:
        ValueError: if a class has fewer than ``labeled_per_class`` samples
            (raised by ``RandomState.choice``).
    """
    labels = np.asarray(labels)
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0

    # A private RandomState(seed) yields exactly the stream np.random.seed(seed)
    # would, so the split is reproducible. It leaves the global NumPy RNG alone
    # instead of re-seeding it from OS entropy afterwards, which would destroy
    # any seed the caller had set.
    rng = np.random.RandomState(seed)
    labeled_idx = []
    for c in range(num_classes):
        class_idx = np.where(labels == c)[0]
        labeled_idx.extend(rng.choice(class_idx, labeled_per_class, replace=False))
    labeled_idx = np.array(labeled_idx, dtype=np.intp)
    rng.shuffle(labeled_idx)

    unlabeled_idx = np.setdiff1d(np.arange(len(labels)), labeled_idx)
    return labeled_idx, unlabeled_idx

class CIFAR100SSL(datasets.CIFAR100):
    """CIFAR-100 split optionally restricted to a subset of sample indices.

    Args:
        root, train, transform, target_transform, download: forwarded unchanged
            to ``datasets.CIFAR100``.
        indexs: indices to keep, or ``None`` to keep the whole split. When given,
            ``self.targets`` is turned into a NumPy array so it can be indexed.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=True):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        if indexs is None:
            return
        self.data = self.data[indexs]
        self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        """Return ``(image, target)``, converting the raw array to a PIL image first."""
        target = self.targets[index]
        image = Image.fromarray(self.data[index])

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target


In [4]:
# Per-channel normalisation statistics passed to Normalize
# (presumably the usual CIFAR-100 training-set mean / std).
mean = (0.5071, 0.4867, 0.4408)
std = (0.2675, 0.2565, 0.2761)

# Labelled stream: horizontal-flip augmentation (may be strengthened later).
transform_labeled = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Unlabelled stream: currently the same pipeline as the labelled one.
# An `AddGaussianNoise()` step is a candidate addition; no such transform
# is defined in this notebook.
transform_unlabeled = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Evaluation: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

In [5]:
root = './tmp'
# Number of labelled images kept per class; adjust to change the labelled/unlabelled split.
labeled_per_class = 100

# Only its `targets` are used here (download=True also fetches the archive).
base_dataset = datasets.CIFAR100(root, train=True, download=True)
train_labeled_idxs, train_unlabeled_idxs = x_u_split(
    base_dataset.targets, labeled_per_class)

# Labelled subset: only the indices chosen by x_u_split.
train_labeled_dataset = CIFAR100SSL(
    root, indexs=train_labeled_idxs, train=True, transform=transform_labeled)

# Unlabelled set: indexs=None keeps every training image (labelled ones included);
# train_unlabeled_idxs is computed above but not used here.
train_unlabeled_dataset = CIFAR100SSL(
    root, indexs=None, train=True, transform=transform_unlabeled)

test_dataset = datasets.CIFAR100(
    root, train=False, transform=transform_test, download=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./tmp/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./tmp/cifar-100-python.tar.gz to ./tmp
Files already downloaded and verified
Files already downloaded and verified


In [6]:
batch_size = 500
# Every stream shares the same batch size.
u_train_batch_size = l_train_batch_size = val_batch_size = batch_size

l_train_dataloader = DataLoader(dataset=train_labeled_dataset, shuffle=True, batch_size=batch_size)
u_train_dataloader = DataLoader(dataset=train_unlabeled_dataset, shuffle=True, batch_size=batch_size)
# The whole test split is evaluated as a single batch.
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=len(test_dataset))

In [7]:
# Network sizes
nz = 100   # length of the noise vector fed to the generator
nc = 3     # number of image channels
ngf = 64   # base number of feature maps in the generator
ndf = 32   # base number of feature maps in the discriminator

# Optimisation
lr = 0.0003
lr_gen = lr_cla = lr  # not read by the visible cells; both optimisers use `lr`
num_epochs = 300

# x-axis step used when plotting the generator / discriminator loss curves
train_interval_gen = 1
train_interval_dis = 1

# The settings below are not read by any visible cell.
clamp_max = 0.01
fm_layer = 'avgpool'  # ResNet feature-matching layer: 'layer1','layer2','layer3','layer4','avgpool','fc'
mul = 1               # scale of the labelled-data loss relative to the Wasserstein loss
Feature_match = True  # whether to use feature matching
label_smooth = 1

In [8]:

class Discriminator(nn.Module):
    """Weight-normalised conv network acting as both GAN critic and classifier.

    ``forward(x)`` returns ``(features, logits)``. ``features`` is the pooled,
    flattened ``ndf * 6``-wide activation (used for feature matching), and
    ``logits`` has ``num_classes`` entries per sample.
    """

    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        narrow = ndf * 3
        wide = ndf * 6

        def conv_lrelu(c_in, c_out, kernel, stride, padding):
            # one weight-normalised convolution followed by its activation
            return [weight_norm(nn.Conv2d(c_in, c_out, kernel, stride=stride, padding=padding)),
                    nn.LeakyReLU()]

        # Layers are appended in the same order as always, so Sequential indices
        # (and therefore state_dict keys of saved checkpoints) stay identical.
        layers = [nn.Dropout(.2)]
        layers += conv_lrelu(3, narrow, 3, 1, 1)
        layers += conv_lrelu(narrow, narrow, 3, 1, 1)
        layers += conv_lrelu(narrow, narrow, 3, 2, 1)

        layers.append(nn.Dropout(.5))
        layers += conv_lrelu(narrow, wide, 3, 1, 1)
        layers += conv_lrelu(wide, wide, 3, 1, 1)
        layers += conv_lrelu(wide, wide, 3, 2, 1)

        layers.append(nn.Dropout(.5))
        layers += conv_lrelu(wide, wide, 3, 1, 0)
        layers += conv_lrelu(wide, wide, 1, 1, 0)
        layers += conv_lrelu(wide, wide, 1, 1, 0)

        layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        self.net = nn.Sequential(*layers)

        self.fc = weight_norm(nn.Linear(wide, num_classes))

    def forward(self, x):
        features = self.net(x)
        return features, self.fc(features)

In [9]:
class Generator(nn.Module):
    """Map noise of shape ``(N, nz)`` (or ``(N, nz, 1, 1)``) to ``(N, nc, 32, 32)`` images.

    A linear layer lifts the noise to ``(ngf*8) x 4 x 4``. Three stride-2
    transposed convolutions then upsample it to 32 x 32, and a final ``Tanh``
    bounds the output to ``[-1, 1]``.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input: (ngf*8) x 4 x 4, produced by `enlargenoise`
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(),
            # state size. (ngf*2) x 16 x 16
            weight_norm(nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1)),
            nn.Tanh()
            # state size. (nc) x 32 x 32
        )
        self.enlargenoise = nn.Sequential(
            nn.Linear(nz, ngf * 8 * 4 * 4),
            nn.BatchNorm1d(ngf * 8 * 4 * 4),
            nn.ReLU(),
        )

    def forward(self, input):
        # flatten(1) keeps the batch dimension. squeeze() turned a single sample
        # (1, nz) into (nz,), which BatchNorm1d rejects. flatten(1) still
        # accepts DCGAN-style (N, nz, 1, 1) noise.
        x = self.enlargenoise(torch.flatten(input, 1))
        x = x.view(-1, ngf * 8, 4, 4)
        return self.main(x)

In [10]:
def init_normal(m):
    """Custom initialisation for use with ``Module.apply``.

    Only modules whose exact type is ``nn.Linear`` or ``nn.ConvTranspose2d`` are
    touched; subclasses and every other module type are left as is. Both get
    ``N(0, 0.05)`` weights. A ``Linear`` bias is zeroed when the layer has one,
    and a ``ConvTranspose2d`` bias is left unchanged.

    NOTE(review): for layers wrapped in ``weight_norm`` (the classifier ``fc``
    and the generator's last deconv), ``m.weight`` is likely a derived tensor
    recomputed from ``weight_g``/``weight_v``. The weight init there probably
    has no lasting effect; confirm.
    """
    layer_type = type(m)
    if layer_type is nn.Linear:
        nn.init.normal_(m.weight, mean=.0, std=.05)
        # layers built with bias=False have `bias is None`; constant_ would raise
        if m.bias is not None:
            nn.init.constant_(m.bias, .0)

    if layer_type is nn.ConvTranspose2d:
        nn.init.normal_(m.weight, mean=0, std=.05)


In [11]:
class FeatureExtractor(nn.Module):
    """Run ``submodule``'s direct children in order and collect selected outputs.

    Args:
        submodule: module whose registered children are applied one after
            another, e.g. a torchvision ResNet. The input is flattened to
            ``(N, -1)`` right before a child named ``"fc"``.
        extracted_layers: names of the children whose outputs are returned.

    ``forward(x)`` returns a dict mapping each requested child name to its output.
    """

    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        # Move to the GPU when one exists, but do not crash on CPU-only machines.
        self.submodule = submodule.cuda() if torch.cuda.is_available() else submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        out_dict = {}
        for name, module in self.submodule._modules.items():
            # `==`, not `is`: identity of equal strings depends on interning
            # and `is` against a literal is a SyntaxWarning on Python 3.8+.
            if name == "fc":
                x = x.reshape(x.size(0), -1)

            x = module(x)

            if name in self.extracted_layers:
                out_dict[name] = x
        return out_dict

In [12]:
def test(model, test_loader, m_test, display = False):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for label_image, label in test_loader:
            label_image,label = label_image.cuda(),label.cuda()
            _,output = model(label_image)
            test_loss += criterion_C(output, label).item() 
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(label.view_as(pred)).sum().item()

    test_loss /= len(test_loader)

    if display == True:
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, m_test,
        100. * correct / m_test))
           
    return test_loss, 100. * correct / m_test/len(test_loader)


In [13]:
def wasserstein_loss(y_true, y_pred):
    """Mean of the element-wise product ``y_true * y_pred``.

    In the training loop ``y_pred`` is a scalar sign (+1 / -1), so this is the
    signed mean of ``y_true``.
    """
    return torch.mean(y_true * y_pred)


def log_sum_exp(x, axis = 1):
    """Numerically stable ``log(sum(exp(x)))`` reduced over dimension ``axis``.

    ``torch.logsumexp`` subtracts the max along the same dimension it sums
    over. A hand-rolled version that always took the max over dim 1 broke for
    any other ``axis``.
    """
    return torch.logsumexp(x, dim=axis)

In [14]:
# Build both networks, apply the custom init, then move them to the GPU.
netG = Generator()
netG.apply(init_normal)
netG = netG.cuda()

netDC = Discriminator(classes)
netDC.apply(init_normal)
netDC = netDC.cuda()

In [15]:
# RMSprop for both networks, sharing the base learning rate `lr`.
optimizer_G = torch.optim.RMSprop(netG.parameters(), lr=lr)
optimizer_DC = torch.optim.RMSprop(netDC.parameters(), lr=lr)

# Both schedulers halve the LR once the validation loss has not improved for
# more than 30 consecutive `step` calls (one per epoch in the training loop).
_plateau_cfg = dict(mode='min', factor=0.5, patience=30, verbose=True,
                    threshold=0.0001, threshold_mode='rel', cooldown=0,
                    min_lr=0, eps=1e-08)
scheduler_G = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_G, **_plateau_cfg)
scheduler_DC = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_DC, **_plateau_cfg)

criterion_C = nn.CrossEntropyLoss().cuda()  # supervised loss on labelled logits
criterion_G = nn.MSELoss().cuda()           # feature-matching loss for the generator

# Fixed noise so the per-epoch preview grid is comparable across epochs.
fixed_noise = torch.randn(batch_size, nz).cuda()

G_losses, DC_losses = [], []

best_model_wts = copy.deepcopy(netDC.state_dict())
best_acc = 0.0

In [None]:


# Training: each epoch makes one pass over the unlabelled loader. Every step
# first updates the discriminator/classifier (netDC), then the generator (netG)
# by feature matching.
checkpoint_dir = '/content/drive/MyDrive/601.682/Final_PJ/model_cifar'

for epoch in range(num_epochs):
    time1 = time.time()

    for u_data, _ in u_train_dataloader:
        # A fresh iterator per step: each step draws a new shuffled labelled
        # batch rather than walking through the labelled set in order.
        l_data, label = next(iter(l_train_dataloader))
        u_data = u_data.cuda()
        l_data = l_data.cuda()
        label = label.cuda()
        noise = torch.randn(u_data.size(0), nz).cuda()

        # ---- discriminator / classifier step ----
        netDC.train()
        optimizer_DC.zero_grad()

        _, l_logit = netDC(l_data)
        _, u_logit = netDC(u_data)
        generated = netG(noise)
        _, fake_logit = netDC(generated.detach())

        # Unsupervised GAN loss on the log-sum-exp of the class logits:
        # 0.5 * -mean(lse_u - softplus(lse_u)) + 0.5 * mean(softplus(lse_fake)).
        lse_u = log_sum_exp(u_logit)
        loss_D = (0.5 * wasserstein_loss(lse_u - F.softplus(lse_u), -1)
                  + 0.5 * wasserstein_loss(F.softplus(log_sum_exp(fake_logit)), 1))
        loss_C = criterion_C(l_logit, label)
        loss_DC = 0.5 * loss_D + 0.5 * loss_C
        loss_DC.backward()
        optimizer_DC.step()
        DC_losses.append(loss_DC.item())

        # ---- generator step (feature matching) ----
        optimizer_DC.zero_grad()
        optimizer_G.zero_grad()
        generated = netG(noise)
        feature_fake, _ = netDC(generated)
        # The real features are a fixed target. They do not depend on netG, and
        # any netDC gradients from this step are cleared before netDC's next
        # update, so no graph is needed for them.
        with torch.no_grad():
            feature_real, _ = netDC(u_data)

        loss_G = criterion_G(torch.mean(feature_fake, dim=0), torch.mean(feature_real, dim=0))
        loss_G.backward()
        optimizer_G.step()
        G_losses.append(loss_G.item())

    # ---- evaluation, LR scheduling, checkpointing ----
    with torch.no_grad():
        netDC.eval()
        test(netDC, l_train_dataloader, len(train_labeled_dataset), True)
        val_loss, val_accuracy = test(netDC, test_dataloader, len(test_dataset), True)
    scheduler_DC.step(val_loss)
    scheduler_G.step(val_loss)

    if val_accuracy > best_acc:
        best_acc = val_accuracy
        best_classifier_wts = copy.deepcopy(netDC.state_dict())
        best_generator_wts = copy.deepcopy(netG.state_dict())
        torch.save(best_generator_wts, checkpoint_dir + '/netG_best100.pth')
        torch.save(best_classifier_wts, checkpoint_dir + '/netDC_best100.pth')

    time2 = time.time()
    # epochs remaining after this one
    print('ETA of completion:', (time2 - time1) * (num_epochs - epoch - 1) / 60, 'minutes: curr_epoch', epoch)

    # ---- preview of generated samples (no autograd graph needed) ----
    fig = plt.figure(figsize=(10, 5))
    with torch.no_grad():
        generated = (netG(fixed_noise) + 1.0) / 2.0
    grid = vutils.make_grid(generated[0:16].cpu())
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.show()

    # ---- loss curves ----
    fig = plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.plot(range(0, len(G_losses) * train_interval_gen, train_interval_gen), np.array(G_losses), label='generator_loss')
    plt.ylim(0, 50)
    plt.xlabel('iteration')
    plt.legend()
    plt.subplot(122)
    plt.plot(range(0, len(DC_losses) * train_interval_dis, train_interval_dis), np.array(DC_losses), label='discriminator_loss')
    plt.ylim(0, 5)
    plt.xlabel('iteration')
    plt.legend()
    plt.show()
