In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file under the Kaggle input directory (prints nothing if it doesn't exist).
kaggle_input_files = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in kaggle_input_files:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from torch import nn
import torch
from torch.autograd import Variable
import math

class coRNNCell(nn.Module):
    """One discrete time step of the coupled oscillatory RNN (coRNN) cell.

    The hidden state is a pair (hy, hz), where hz acts as the rate of change
    of hy. Each step first updates hz from a learned tanh drive minus the
    damping terms gamma * hy and epsilon * hz. It then advances hy using the
    *updated* hz.

    Args:
        n_inp: width of the per-step input x.
        n_hid: width of each hidden state (hy and hz).
        dt: time-step size.
        gamma: coefficient on hy in the hz update.
        epsilon: coefficient on hz in the hz update.
    """

    def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
        super().__init__()
        self.dt, self.gamma, self.epsilon = dt, gamma, epsilon
        # A single affine map acts on the concatenation [x, hz, hy].
        self.i2h = nn.Linear(n_inp + 2 * n_hid, n_hid)

    def forward(self, x, hy, hz):
        """Advance the state by one step.

        Args:
            x: input of shape (batch, n_inp).
            hy, hz: current states, each of shape (batch, n_hid).

        Returns:
            The new (hy, hz) tuple.
        """
        drive = torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
        hz = hz + self.dt * (drive - self.gamma * hy - self.epsilon * hz)
        # hy uses the freshly updated hz, not the previous one.
        hy = hy + self.dt * hz
        return hy, hz

class coRNN(nn.Module):
    """Sequence classifier: unroll a coRNNCell over time, then read out the final hy.

    Args:
        n_inp: features per time step.
        n_hid: hidden-state width.
        n_out: number of output logits.
        dt, gamma, epsilon: forwarded unchanged to coRNNCell.
    """

    def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
        super(coRNN, self).__init__()
        self.n_hid = n_hid
        self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
        self.readout = nn.Linear(n_hid, n_out)

    def forward(self, x):
        """Run the recurrence over the whole sequence.

        Args:
            x: time-major input of shape (seq_len, batch, n_inp); x[t] is fed at step t.

        Returns:
            Logits of shape (batch, n_out), computed from the last hy state.
        """
        # Zero initial states are allocated "like x", so they follow its device and dtype.
        # The original CPU-only torch.zeros broke as soon as model/data were moved to a GPU.
        # torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4.
        hy = x.new_zeros(x.size(1), self.n_hid)
        hz = x.new_zeros(x.size(1), self.n_hid)

        for t in range(x.size(0)):
            hy, hz = self.cell(x[t], hy, hz)

        return self.readout(hy)


In [None]:
class Myclass:
    """Plain attribute bag holding the run's hyperparameters (an argparse stand-in)."""

    def __init__(self):
        self.n_hid = 256          # hidden-state width passed to coRNN
        self.T = 100              # T and max_steps are not read anywhere in this notebook
        self.max_steps = 60000
        self.log_interval = 100   # presumably steps between progress logs
        self.batch = 120          # training batch size (passed to get_data)
        self.batch_test = 1000    # evaluation batch size
        self.lr = 0.0021          # Adam learning rate; divided by 10 during training
        self.dt = 0.042           # coRNN time step
        self.gamma = 2.7          # coRNN gamma coefficient
        self.epsilon = 4.7        # coRNN epsilon coefficient
        self.epochs = 120         # number of training epochs


# Global config object read by the cells below.
args = Myclass()

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

def get_data(bs_train, bs_test, valid_size=3000, root='data/'):
    """Build MNIST train / validation / test DataLoaders.

    The MNIST training split is randomly divided into a training subset and a
    held-out validation subset of ``valid_size`` images. The default of 3000
    reproduces the previous 57000/3000 split. ``random_split`` is called without
    an explicit generator, so it draws from torch's global RNG. Seed it first
    with ``torch.manual_seed`` for a reproducible split.

    Args:
        bs_train: batch size of the shuffled training loader.
        bs_test: batch size of the (unshuffled) validation and test loaders.
        valid_size: number of training images held out for validation.
        root: directory where torchvision stores or finds the MNIST files.

    Returns:
        (train_loader, valid_loader, test_loader)

    Raises:
        ValueError: if ``valid_size`` is negative or larger than the training split.
    """
    to_tensor = transforms.ToTensor()
    full_train = torchvision.datasets.MNIST(root=root,
                                            train=True,
                                            transform=to_tensor,
                                            download=True)
    # download=True is a no-op when the files are already present; it lets
    # the test split load even if the train split was never fetched here.
    test_dataset = torchvision.datasets.MNIST(root=root,
                                              train=False,
                                              transform=to_tensor,
                                              download=True)

    n_total = len(full_train)
    if not 0 <= valid_size <= n_total:
        raise ValueError(f"valid_size must be in [0, {n_total}], got {valid_size}")
    # Derive split sizes from the dataset length rather than hard-coding them.
    train_dataset, valid_dataset = torch.utils.data.random_split(
        full_train, [n_total - valid_size, valid_size])

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=bs_train,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset,
                                               batch_size=bs_test,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=bs_test,
                                              shuffle=False)

    return train_loader, valid_loader, test_loader

In [None]:
from torch import nn, optim
import torch
# import network
import torch.nn.utils
# import utils
from pathlib import Path
import argparse

# Seed torch's global RNG before the model is built (weight init).
# This also fixes later global-RNG draws, e.g. the train/valid split and the shuffling in get_data.
torch.manual_seed(46159)

n_inp = 1   # one pixel per time step (images are fed as a 784-step sequence)
n_out = 10  # MNIST digit classes
# Take the evaluation batch size from the config object instead of repeating the literal 1000.
bs_test = args.batch_test

model = coRNN(n_inp, args.n_hid, n_out, args.dt, args.gamma, args.epsilon)

objective = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

In [None]:
model  # last expression of the cell: the notebook displays the module's repr (its layers)

In [None]:
# Build the three loaders: training batch size from the config, evaluation batch size from bs_test.
train_loader, valid_loader, test_loader = get_data(bs_train=args.batch, bs_test=bs_test)

In [None]:
import itertools

from PIL import Image

# Export the first few test images as PNG files named "<label>_<index>.png" for visual inspection.
downloaded_images_count = 0
num_images_to_download = 20
output_directory = 'downloaded_images/'
os.makedirs(output_directory, exist_ok=True)

# Flatten the batched loader into (image, label) pairs. islice stops after the first N,
# which replaces the original nested-loop double break.
samples = (
    (image, label)
    for images, labels in test_loader
    for image, label in zip(images, labels)
)
for image, label in itertools.islice(samples, num_images_to_download):
    # ToTensor scaled pixels into [0, 1]. Round (don't truncate) when mapping back to 0..255:
    # float error can land just below an integer, and astype would floor it off by one.
    pixels = np.clip(np.rint(image.squeeze().numpy() * 255), 0, 255).astype(np.uint8)

    # A 2-D uint8 array becomes an 8-bit grayscale ('L') image without passing
    # the (deprecated) mode argument.
    pil_image = Image.fromarray(pixels)

    # Save the image with its label in the filename.
    filename = f"{label.item()}_{downloaded_images_count}.png"
    pil_image.save(os.path.join(output_directory, filename))

    downloaded_images_count += 1

print(f"{downloaded_images_count} images downloaded to {output_directory}")


In [None]:
# Peek at a single training batch to check the tensor layout the loader yields.
for images, labels in train_loader:
    print(images.shape)
    break

In [None]:
def test(data_loader):
    model.eval()
    correct = 0
    test_loss = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_loader):
            images = images.reshape(bs_test, 1, 784)
            images = images.permute(2, 0, 1)

            output = model(images)
            test_loss += objective(output, labels).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(labels.data.view_as(pred)).sum()
    test_loss /= i+1
    accuracy = 100. * correct / len(data_loader.dataset)

    return accuracy.item()




In [None]:
import copy

clip_value = 0.5  # threshold for the (currently disabled) gradient clipping below
best_valid_accuracy = 0.0
best_model_state_dict = None
log_path = Path('result') / 'sMNIST_log.txt'
# Create the log directory up front so the final write works even if no epoch runs.
log_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(args.epochs):
    print(f"big loop {epoch + 1}")
    model.train()
    for step, (images, labels) in enumerate(train_loader, start=1):
        # (B, 1, 28, 28) -> (B, 1, 784) -> (784, B, 1): the coRNN sees one pixel per step.
        # Use the actual batch size so a smaller final batch does not break the reshape.
        images = images.reshape(images.size(0), 1, -1).permute(2, 0, 1)

        optimizer.zero_grad()
        output = model(images)
        loss = objective(output, labels)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()

        # Log periodically instead of printing every single step.
        if step % args.log_interval == 0:
            print(f"epoch {epoch + 1} step {step} loss {loss.item():.4f}")

    # Model selection uses the validation split only. The per-epoch test-set pass
    # was never used, so it has been dropped.
    valid_acc = test(valid_loader)
    if valid_acc > best_valid_accuracy:
        best_valid_accuracy = valid_acc
        # state_dict() shares storage with the live parameters, so deep-copy it.
        # Otherwise the "best" snapshot would silently track every later update.
        best_model_state_dict = copy.deepcopy(model.state_dict())
        torch.save(best_model_state_dict, 'pminst_model_checkpoint.pth')
        print(best_valid_accuracy)

    with open(log_path, 'a') as f:
        if epoch == 0:
            f.write('## learning rate = ' + str(args.lr) + ', dt = ' + str(args.dt) + ', gamma = ' + str(args.gamma) + ', epsilon = ' + str(args.epsilon) + '\n')
        f.write('eval accuracy: ' + str(round(valid_acc, 2)) + '\n')

    # Step decay: divide the learning rate by 10 every 100 epochs.
    if (epoch + 1) % 100 == 0:
        args.lr /= 10.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr

# Restore the best-validation weights before the final evaluation.
if best_model_state_dict is not None:
    model.load_state_dict(best_model_state_dict)

# Test the model
test_acc = test(test_loader)

print('Test set:  Accuracy: {:.2f}%\n'.format(test_acc))
with open(log_path, 'a') as f:
    f.write('final test accuracy: ' + str(round(test_acc, 2)) + '\n')

In [None]:
# Re-evaluate the currently loaded weights on the test set and append the result to the run log.
# Note: each run of this cell appends another 'final test accuracy' line ('a' mode).
test_acc = test(test_loader)

print('Test set:  Accuracy: {:.2f}%\n'.format(test_acc))
log_path = Path('result') / 'sMNIST_log.txt'
# Make the cell work on its own even if the training cell never created the directory.
log_path.parent.mkdir(parents=True, exist_ok=True)
with open(log_path, 'a') as f:
    f.write('final test accuracy: ' + str(round(test_acc, 2)) + '\n')