In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file in it,
# in os.walk order (directory by directory).
input_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from torch import nn
import torch
from torch.autograd import Variable
import math
from torch.nn import init
# Use the current CUDA device when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class coRNNCell(nn.Module):
    """One explicit update step of the coupled oscillatory RNN (coRNN) cell.

    With step size ``dt`` the cell computes

        hz' = hz + dt * (tanh(h2h1(hy) + h2h2(hz) + i2h(x)) - gamma * hy - epsilon * hz)
        hy' = hy + dt * hz'

    i.e. ``hy'`` is advanced with the *already updated* ``hz'``.

    Args:
        n_inp: number of input features per step (last dim of ``x``).
        n_hid: number of hidden units (size of both ``hy`` and ``hz``).
        dt: integration step size.
        gamma: coefficient of the restoring term ``-gamma * hy``.
        epsilon: coefficient of the damping term ``-epsilon * hz``.
    """

    def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
        super().__init__()
        self.dt, self.gamma, self.epsilon = dt, gamma, epsilon
        # Creation order fixes RNG consumption at init and the state_dict key order; keep it.
        self.i2h = nn.Linear(n_inp, n_hid, bias=True)    # input drive (with bias)
        self.h2h1 = nn.Linear(n_hid, n_hid, bias=False)  # coupling from hy
        self.h2h2 = nn.Linear(n_hid, n_hid, bias=False)  # coupling from hz

    def forward(self, x, hy, hz):
        """Advance the state by one step and return ``(hy_next, hz_next)``."""
        drive = torch.tanh(self.h2h1(hy) + self.h2h2(hz) + self.i2h(x))
        hz_next = hz + self.dt * (drive - self.gamma * hy - self.epsilon * hz)
        hy_next = hy + self.dt * hz_next
        return hy_next, hz_next

class coRNN(nn.Module):
    """Sequence classifier: unrolls a coRNNCell over time and reads out the final ``hy``.

    Args:
        n_inp: features per time step.
        n_hid: hidden units of the cell.
        n_out: number of output logits (classes).
        dt, gamma, epsilon: cell hyper-parameters, passed through to ``coRNNCell``.
    """

    def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
        super(coRNN, self).__init__()
        self.n_hid = n_hid
        self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
        self.readout = nn.Linear(n_hid, n_out)

    def forward(self, x):
        """Run the cell over ``x`` and return logits of shape ``(batch, n_out)``.

        ``x`` is time-major: ``x[t]`` is the ``(batch, n_inp)`` input at step ``t``.
        """
        # Zero initial states allocated on x's own device/dtype (no deprecated
        # Variable, no dependence on a notebook-global device).
        hy = x.new_zeros(x.size(1), self.n_hid)
        hz = x.new_zeros(x.size(1), self.n_hid)

        for x_t in x:  # iterates over dim 0, i.e. time
            hy, hz = self.cell(x_t, hy, hz)

        return self.readout(hy)


In [None]:
from torch import nn, optim
import torch
# import utils
# import network
import argparse
import torch.nn.utils
from pathlib import Path



class Myclass:
    """Plain hyper-parameter holder used in place of an argparse Namespace.

    Every setting becomes an instance attribute (``args.lr``, ``args.batch``, ...).
    """

    def __init__(self):
        settings = {
            # coRNN size
            'n_hid': 128,
            'T': 100,
            'embedding': 100,
            'max_steps': 60000,
            'log_interval': 100,
            # loader batch sizes
            'batch': 100,
            'batch_test': 1000,
            # optimizer and cell hyper-parameters
            'lr': 0.0075,
            'dt': 0.034,
            'gamma': 1.3,
            'epsilon': 12.7,
            'epochs': 10,
        }
        for name, value in settings.items():
            setattr(self, name, value)


args = Myclass()


In [None]:
# Each time step feeds one image row: 3 colour channels x 32 pixels = 96 features
# (see the permute/reshape in the evaluation and training cells); CIFAR-10 has 10 classes.
n_inp = 96
n_out = 10
# Move the model to its device *before* building the optimizer, as the torch.optim
# docs recommend, so the optimizer is constructed over the final parameters.
model = coRNN(n_inp, args.n_hid, n_out, args.dt, args.gamma, args.epsilon).to(device)
## Define the loss
objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)


In [None]:
# Place the parameters on `device`; as the cell's last expression it also displays the module repr.
model.to(device=device)

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

def get_data(bs_train, bs_test, root='data/', n_valid=3000, generator=None):
    """Build CIFAR-10 train / validation / test DataLoaders.

    Images are only converted with ``transforms.ToTensor()`` (no normalization).
    The official training set is randomly split into a training part and a
    held-out validation part of ``n_valid`` images; the official test set is
    used as-is.

    Args:
        bs_train: batch size of the training loader (shuffled every epoch).
        bs_test: batch size of the validation and test loaders (not shuffled).
        root: directory CIFAR-10 is stored in / downloaded to.
        n_valid: number of training images held out for validation. The default
            3000 reproduces the original 47000 / 3000 split of the 50000 images.
        generator: optional ``torch.Generator`` making the train/validation split
            reproducible; ``None`` uses PyTorch's global RNG as before.

    Returns:
        (train_loader, valid_loader, test_loader)
    """
    to_tensor = transforms.ToTensor()
    train_dataset = torchvision.datasets.CIFAR10(root=root,
                                                 train=True,
                                                 transform=to_tensor,
                                                 download=True)
    # No download flag: the archive fetched for the training split also contains the test split.
    test_dataset = torchvision.datasets.CIFAR10(root=root,
                                                train=False,
                                                transform=to_tensor)

    # Derive the split from the dataset size instead of hard-coding 47000.
    lengths = [len(train_dataset) - n_valid, n_valid]
    split_kwargs = {} if generator is None else {'generator': generator}
    train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, lengths, **split_kwargs)

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=bs_train,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=bs_test,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=bs_test,
                                              shuffle=False)

    return train_loader, valid_loader, test_loader

In [None]:
# Use the configured evaluation batch size (args.batch_test, 1000) rather than repeating the literal.
train_loader, valid_loader, test_loader = get_data(args.batch, args.batch_test)


In [None]:
# Peek at one training batch to check tensor shapes before training.
images, labels = next(iter(train_loader))
print(images.shape)
print(labels.shape)

In [None]:
# One fixed block of Gaussian noise (1000 - 32 steps x 96 features), appended after the
# 32 image rows so every sequence is 1000 steps long with the image in the first 32 steps.
rands = torch.randn(1, 1000 - 32, 96)
# expand() returns broadcast views rather than copies: repeat() materialized roughly
# 370 MB for the evaluation batch. The values are identical, and torch.cat copies them anyway.
rand_train = rands.expand(args.batch, -1, -1)
rand_test = rands.expand(args.batch_test, -1, -1)

def test(data_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in data_loader:
            ## Reshape images for sequence learning:
            images = torch.cat((images.permute(0,2,1,3).reshape(1000,32,96),rand_test),dim=1).permute(1,0,2)
            output = model(images.to(device))
            output.to(device)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(labels.to(device).data.view_as(pred.to(device))).sum()
    accuracy = 100. * correct / len(data_loader.dataset)

    return accuracy.item()

In [None]:
best_eval = 0.
# Defined up front so the final log line works even if validation accuracy never
# rises above 0 or args.epochs == 0 (otherwise the last write raises NameError).
final_test_acc = 0.
log_path = Path('result') / 'noisy_cifar_log.txt'
log_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(args.epochs):
    print("big loop %d" % (epoch + 1))
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        n = images.size(0)  # actual batch size; a final batch may be short
        ## Reshape images for sequence learning: each image row (3 channels x 32 px = 96
        ## features) is one time step, followed by the fixed noise steps; then make the
        ## tensor time-major, (seq_len, batch, 96), the layout coRNN.forward iterates over.
        images = torch.cat(
            (images.permute(0, 2, 1, 3).reshape(n, 32, 96), rand_train[:n]), dim=1
        ).permute(1, 0, 2)
        # Training pass
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = objective(output, labels.to(device))
        loss.backward()
        optimizer.step()
        # Periodic progress instead of printing a counter for every batch.
        if (i + 1) % args.log_interval == 0:
            print("batch %d, loss %.4f" % (i + 1, loss.item()))

    # Model selection: keep the test accuracy of the epoch with the best validation accuracy.
    valid_acc = test(valid_loader)
    test_acc = test(test_loader)
    if valid_acc > best_eval:
        best_eval = valid_acc
        final_test_acc = test_acc

    with open(log_path, 'a') as f:
        if epoch == 0:
            f.write('## learning rate = ' + str(args.lr) + ', dt = ' + str(args.dt) + ', gamma = '
                    + str(args.gamma) + ', epsilon = ' + str(args.epsilon) + '\n')
        f.write('eval accuracy: ' + str(round(valid_acc, 2)) + '\n')

    # Step decay: divide the learning rate by 10 every 100 epochs
    # (never fires with the configured args.epochs = 10).
    if (epoch + 1) % 100 == 0:
        args.lr /= 10.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr

with open(log_path, 'a') as f:
    f.write('final test accuracy: ' + str(round(final_test_acc, 2)) + '\n')

In [None]:
import pickle

In [None]:
# Persist the whole trained module with pickle; the `with` block guarantees the file is
# flushed and closed. Despite the ".h5" suffix this is a pickle file, not HDF5, and
# unpickling needs the coRNN / coRNNCell classes importable under the same names.
with open("cornn.h5", "wb") as f:
    pickle.dump(model, f)


In [None]:
# Reload the coRNN pickled above and score it on the first test batch.
# NOTE: pickle.load can execute arbitrary code -- only load files this notebook wrote.
with open("cornn.h5", "rb") as f:
    cornn_loaded = pickle.load(f)
cornn_loaded.eval()

images, labels = next(iter(test_loader))
n = images.size(0)
# Same row-sequence + noise preprocessing as evaluation, time-major (seq_len, batch, 96).
seq = torch.cat((images.permute(0, 2, 1, 3).reshape(n, 32, 96), rand_test[:n]), dim=1).permute(1, 0, 2)
with torch.no_grad():
    preds = cornn_loaded(seq.to(device)).argmax(dim=1).cpu()
reloaded_batch_acc = 100. * (preds == labels).float().mean().item()
reloaded_batch_acc
