In [1]:
import numpy as np
import pandas as pd
import os
import shutil
import glob
import matplotlib.pyplot as plt
import random
import torch
from torch import nn
import torchvision
import torch.utils.data
from torchvision.utils import save_image
import torch.optim as optim
from PIL import Image

from generative import main as generator
from solver import main as solver

from generative import GNetwork,DNetwork
from solver import Network,dataset,load_dataset

In [2]:
def plot_images(x):
    """Draw the first nine entries of ``x`` on a 3x3 grid of subplots.

    Each ``x[i]`` is squeezed before plotting, so singleton axes (such as a
    leading channel axis of size 1) are dropped before ``plt.imshow``.
    Raises IndexError if ``x`` holds fewer than nine entries.
    """
    plt.figure(figsize=(5, 5))
    for slot in range(1, 10):
        plt.subplot(3, 3, slot)
        plt.imshow(x[slot - 1].squeeze())

def permute_image_pixel(image,permutation):
    """Apply one pixel permutation to every channel of a single image.

    Args:
        image: array of shape (c, h, w).
        permutation: integer index array of length h*w. Output pixel k of each
            channel (row-major order) takes the value of input pixel
            ``permutation[k]`` of that same channel.

    Returns:
        A new array of shape (c, h, w) holding the permuted image.
    """
    c, h, w = image.shape
    # Flatten each channel on its own -> (c, h*w). The layout is channel-major,
    # so the previous ``reshape(-1, c)`` mixed values from different channels
    # whenever c > 1; for c == 1 the two forms give identical results.
    flat = image.reshape(c, h * w)
    return flat[:, permutation].reshape(c, h, w)

def permute_data(data,permutation):
    """Apply one pixel permutation to every image of a batch, in place.

    Args:
        data: array of shape (n, c, h, w); it is overwritten with the
            permuted images.
        permutation: integer index array of length h*w (row-major pixel
            order), applied identically to every channel of every image.

    Returns:
        ``data`` itself (the same object), now holding the permuted images.
    """
    n, c, h, w = data.shape
    # One vectorised gather over the flattened pixel axis replaces the per-image
    # Python loop. Explicit sizes (not -1) keep reshape valid when n == 0.
    # Fancy indexing produces a copy, so writing it back into ``data`` never
    # reads values that were already overwritten.
    data[...] = data.reshape(n, c, h * w)[:, :, permutation].reshape(n, c, h, w)
    return data

In [3]:
def get_dataset(permutation, csv_path="./digit-recognizer/train.csv", split_ratio=0.1):
    """Load the digit-recognizer CSV and split it into train and validation sets.

    The first CSV column is taken as the label and the remaining columns as
    pixel values, reshaped to (n, 1, 28, 28). When the global ``opt.permute``
    is truthy, every image is pixel-permuted with ``permutation``.

    Args:
        permutation: pixel index array of length 28*28; ignored when
            ``opt.permute`` is falsy.
        csv_path: CSV file to read (pandas default: first row is the header).
        split_ratio: fraction of rows, taken from the start of the file, that
            form the validation split.

    Returns:
        (xtrain, ytrain, xtest, ytest) as numpy arrays.
    """
    data = pd.read_csv(csv_path).values

    # separate images and labels
    x, y = data[:, 1:], data[:, 0]

    # reshape data to (-1, c, h, w)
    x = x.reshape((-1, 1, 28, 28))

    # permute data (controlled by the global config class ``opt``)
    if opt.permute:
        x = permute_data(x, permutation)

    # the first `split` rows are the validation set, the rest are training data
    split = int(x.shape[0] * split_ratio)
    xtrain, ytrain = x[split:], y[split:]
    xtest, ytest = x[:split], y[:split]

    assert xtrain.shape[0] == ytrain.shape[0]
    assert xtest.shape[0] == ytest.shape[0]

    return xtrain, ytrain, xtest, ytest

In [4]:
class opt:
    """Run configuration for the notebook, read through class attributes."""

    # compute device: the first CUDA GPU when one is available, else the CPU
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # mini-batch size passed to the generator, solver and validation loaders
    batch_size = 64

    # training epochs run for each task
    epoch_generator = 10
    epoch_solver = 10

    # fraction of each later task's samples that is replaced by replay samples
    r = 0.5

    # 1 -> also replace training samples with replay data (validation always is)
    shuffle = 1
    # 1 -> pixel-permute the images of every task
    permute = 1

In [5]:
def validation_accuracy(xtest,ytest):
    """Print and return the accuracy of the saved solver on (xtest, ytest).

    Loads the weights in ``./solvers/solver.pt`` into a fresh ``Network``,
    evaluates it batch by batch on ``opt.device`` and prints the result.

    Returns:
        Accuracy in percent, or ``float("nan")`` when there is nothing to score.
    """
    testloader = load_dataset(xtest, ytest, opt.batch_size)
    net = Network().to(opt.device)
    # map_location lets a checkpoint written on a GPU load on a CPU-only host
    net.load_state_dict(torch.load("./solvers/solver.pt", map_location=opt.device))
    net.eval()

    correct = 0
    total = 0
    # inference only: skip building autograd graphs
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(opt.device), y.to(opt.device)
            out = net(x)
            hits = torch.argmax(out, dim=1) == y
            correct += int(hits.sum().item())
            total += hits.numel()

    # guard the empty case instead of dividing by zero
    accuracy = (correct / total) * 100 if total else float("nan")
    print(f"Validation Accuracy : {accuracy}")
    return accuracy
    
def shuffle_data(size,x,y,chunk_size=1024):
    """Replace the last ``size`` samples of (x, y) with generative-replay samples.

    ``size`` latent vectors of shape (50, 1, 1) are drawn from a standard
    normal distribution. They are decoded by the generator saved at
    ``./generators/gen.pt`` and labelled with the argmax prediction of the
    solver saved at ``./solvers/solver.pt``. Despite the name, nothing is
    shuffled: the last ``size`` rows of copies of ``x`` and ``y`` are
    overwritten.

    Args:
        size: number of samples to replace, in ``[0, len(x)]``; ``0`` returns
            ``x`` and ``y`` unchanged.
        x: sample array whose first axis indexes samples.
        y: label array aligned with ``x``.
        chunk_size: number of latent vectors decoded per forward pass.

    Returns:
        (x, y) as new arrays. An integer-typed ``x`` is promoted to float32 so
        the generated images are not truncated.

    Raises:
        ValueError: if ``size`` is negative or larger than ``len(x)``.
    """
    n = len(x)
    if size < 0 or size > n:
        raise ValueError(f"size must be in [0, {n}], got {size}")
    if size == 0:
        return x, y

    z = torch.randn(size, 50, 1, 1, device=opt.device)

    netG = GNetwork().to(opt.device)
    netG.load_state_dict(torch.load("./generators/gen.pt", map_location=opt.device))
    netG.eval()

    net = Network().to(opt.device)
    net.load_state_dict(torch.load("./solvers/solver.pt", map_location=opt.device))
    net.eval()

    # inference only; decode in chunks so peak memory does not grow with `size`
    fake_x_parts, fake_y_parts = [], []
    with torch.no_grad():
        for batch_z in torch.split(z, chunk_size):
            batch_x = netG(batch_z)
            fake_x_parts.append(batch_x.cpu().numpy())
            fake_y_parts.append(torch.argmax(net(batch_x), dim=1).cpu().numpy())
    fakex = np.concatenate(fake_x_parts)
    fakey = np.concatenate(fake_y_parts)

    # Writing float generator output into an integer array (e.g. raw CSV pixels)
    # would silently truncate it, so promote integer inputs to float32.
    # NOTE(review): confirm the generator's output scale matches the real pixels.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float32
    x = x.astype(out_dtype)  # copy: the caller's array is left untouched
    y = y.copy()

    # Replace exactly the last `size` rows. The old `x[size:]` only matched
    # `size` rows when size == len(x) / 2.
    start = n - size
    x[start:] = fakex
    y[start:] = fakey

    return x, y

In [6]:
# Number of sequential tasks; task i uses a pixel permutation seeded with i.
N_TASKS = 2

# Checkpoint paths handed to the training entry points (loop-invariant).
# NOTE(review): validation_accuracy and shuffle_data read the same files through
# their own hard-coded strings; keep the two in sync.
dis = "./generators/dis.pt"
gen = "./generators/gen.pt"
sol = "./solvers/solver.pt"

for i in range(N_TASKS):

    # task-specific pixel permutation (0 is a placeholder when permuting is off)
    permutation = np.random.RandomState(seed=i).permutation(28 * 28) if opt.permute else 0
    xtrain, ytrain, xtest, ytest = get_dataset(permutation)
    train_data = xtrain.shape[0]
    test_data = xtest.shape[0]
    print(f"Total training data : {train_data}")
    print(f"Total validation data : {test_data}")

    if i > 0:  # from the second task on, mix in replay samples of earlier tasks

        train_size = int(train_data * opt.r)
        test_size = int(test_data * opt.r)

        if opt.shuffle:  # generative replay on the training set
            xtrain, ytrain = shuffle_data(train_size, xtrain, ytrain)

        # the validation set always receives replay samples
        xtest, ytest = shuffle_data(test_size, xtest, ytest)

    generator(xtrain, opt.batch_size, opt.epoch_generator, gen, dis)
    solver(xtrain, ytrain, opt.batch_size, opt.epoch_solver, sol)

    validation_accuracy(xtest, ytest)

Total training data : 37800
Total validation data : 4200
DISCRIMNATOR LOSS : 180.59560762159526 GENERATOR LOSS : 2062.991374373436
DISCRIMNATOR LOSS : 118.2518248166889 GENERATOR LOSS : 2441.2467876672745
DISCRIMNATOR LOSS : 246.52121305465698 GENERATOR LOSS : 1642.6087901592255
DISCRIMNATOR LOSS : 242.94073355197906 GENERATOR LOSS : 1642.5745896697044
DISCRIMNATOR LOSS : 344.1497169137001 GENERATOR LOSS : 1391.3572830557823
DISCRIMNATOR LOSS : 381.8490227162838 GENERATOR LOSS : 1279.5707722306252
DISCRIMNATOR LOSS : 383.96165585517883 GENERATOR LOSS : 1311.5200257897377
DISCRIMNATOR LOSS : 433.80023887753487 GENERATOR LOSS : 1232.8664441108704
DISCRIMNATOR LOSS : 408.35310393571854 GENERATOR LOSS : 1248.4821232557297
DISCRIMNATOR LOSS : 413.26560840010643 GENERATOR LOSS : 1254.0531772375107
Epoch : 0 Loss : 328.84792252629995
Epoch : 1 Loss : 125.4957245029509
Epoch : 2 Loss : 83.94992697238922
Epoch : 3 Loss : 62.51199135184288
Epoch : 4 Loss : 48.749596955254674
Epoch : 5 Loss : 37.