In [223]:
import numpy as np
import torch
import torchvision
import cv2
import base64
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn.functional as F


In [224]:
# image_path = './mnist_theera_image.png'
# image = cv2.imread(image_path)
# image_array = np.array(image,dtype = np.uint8)
# # npArr = np.frombuffer(base64.b64decode(image_array), np.uint8)
# # img = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
# grayImage = cv2.cvtColor(image_array, cv2.COLOR_BGR2GRAY)
# gray_image = cv2.resize(grayImage, (28,28), interpolation=cv2.INTER_LINEAR)
# plt.imshow(gray_image)

In [243]:

class mnist_dataset(Dataset):
    def __init__(self, original_dataset):
        self.original_dataset = original_dataset
    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self,index):
        image, label = self.original_dataset[index]
        image = self.inject_matrix(image)
        return image, label

    def inject_matrix(self,smaller_tensor):
        # Get the shape of the matrices
        # Create a larger tensor of size (200, 700) filled with zeros
        larger_tensor = np.zeros((56,56), dtype=np.float32)

        # target_size = (150, 150)
        # smaller_tensor = smaller_tensor.unsqueeze(0)
        # # Use torch.nn.functional.interpolate to resize the tensor
        # smaller_tensor = F.interpolate(smaller_tensor, size=target_size, mode='bilinear', align_corners=False)
        smaller_tensor = smaller_tensor.squeeze()
        # Get the dimensions of the smaller tensor
        smaller_rows, smaller_cols = smaller_tensor.shape
        
        # Generate random starting coordinates within the valid range
        row_start = np.random.randint(0, larger_tensor.shape[0] - smaller_rows + 1)
        col_start = np.random.randint(0, larger_tensor.shape[1] - smaller_cols + 1)
        
        # Inject the smaller tensor into the larger tensor at the random location
        larger_tensor[row_start:row_start + smaller_rows, col_start:col_start + smaller_cols] = smaller_tensor
        return larger_tensor[np.newaxis,:]




In [244]:
class MyCnn(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
      torch.nn.ReLU(),
      torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3),
      torch.nn.ReLU(),
      torch.nn.Dropout2d(p=0.1)
    )
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
      torch.nn.ReLU(),
      torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
      torch.nn.ReLU(),
      torch.nn.Dropout2d(p=0.1)
    )
    self.linear_relu_stack = torch.nn.Sequential(
      torch.nn.Flatten(),
      torch.nn.Linear(in_features=147456, out_features=128),
      torch.nn.ReLU(),
      torch.nn.Dropout(p=0.1),
      torch.nn.Linear(in_features=128, out_features=10)
    )

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    logits = self.linear_relu_stack(x)
    return logits

In [245]:
# train_ds = torchvision.datasets.MNIST(root="data", train=True,  download=True, transform=torchvision.transforms.ToTensor())
# train_ds = mnist_dataset(train_ds)
# test_ds = torchvision.datasets.MNIST(root="data", train=False,  download=True, transform=torchvision.transforms.ToTensor())
# test_ds = mnist_dataset(test_ds)
# train_dl = torch.utils.data.DataLoader(train_ds, batch_size=1)
# test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1)
# for batch in train_dl:
#     # print(batch[0].squeeze().size())
#     # mnist_image = batch[0].squeeze()
#     a,b = batch
#     print(a.size())
#     print(b)
#     # plt.figure()
#     # plt.imshow(a[0])
#     # a = a.unsqueeze(0)
#     # a = F.interpolate(a, size=(28,28), mode='bilinear', align_corners=False)
#     # plt.figure()
#     plt.imshow(a.squeeze().squeeze())
#     break

In [247]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_ds = torchvision.datasets.MNIST(root="data", train=True,  download=True, transform=torchvision.transforms.ToTensor())
train_ds = mnist_dataset(train_ds)
test_ds = torchvision.datasets.MNIST(root="data", train=False,  download=True, transform=torchvision.transforms.ToTensor())
test_ds = mnist_dataset(test_ds)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=512)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=512)

myCnn = MyCnn().to(device)
loss = torch.nn.CrossEntropyLoss()
opt = torch.torch.optim.SGD(myCnn.parameters(), lr=1e-1)

for t in range(20):
    for idx, (x, y) in enumerate(train_dl):
      x, y = x.to(device), y.to(device)
      logits = myCnn(x)
      l = loss(logits, y)
      opt.zero_grad()
      l.backward()
      opt.step()
      if idx % 50 == 0: print(f"Epoch: {t}, Loss: {l.item()}")

    with torch.no_grad():
      correct = 0
      total = 0
      for x, y in test_dl:
        x, y = x.to(device), y.to(device)
        logits = myCnn(x)
        _, pred = torch.max(logits, dim=1)
        total += y.size(0)
        correct += (pred == y).sum().item()
      print(f"Epoch: {t}, Accuracy: {correct/total}")
    torch.save(myCnn.state_dict(), f"./mnist-{t}.pt")


Epoch: 0, Loss: 2.305835723876953
Epoch: 0, Loss: 2.299039363861084
Epoch: 0, Loss: 2.3031604290008545
Epoch: 0, Accuracy: 0.1281
Epoch: 1, Loss: 2.2974438667297363
Epoch: 1, Loss: 2.2706565856933594
Epoch: 1, Loss: 2.3041458129882812
Epoch: 1, Accuracy: 0.1135
Epoch: 2, Loss: 2.3003246784210205
Epoch: 2, Loss: 2.298281192779541
Epoch: 2, Loss: 2.305053234100342
Epoch: 2, Accuracy: 0.1135
Epoch: 3, Loss: 2.298513174057007
Epoch: 3, Loss: 2.298928737640381
Epoch: 3, Loss: 2.3055102825164795
Epoch: 3, Accuracy: 0.1135
Epoch: 4, Loss: 2.297867774963379
Epoch: 4, Loss: 2.2986040115356445
Epoch: 4, Loss: 2.304978370666504
Epoch: 4, Accuracy: 0.1135
Epoch: 5, Loss: 2.2975473403930664
Epoch: 5, Loss: 2.2971365451812744
Epoch: 5, Loss: 2.3023922443389893
Epoch: 5, Accuracy: 0.1161
Epoch: 6, Loss: 2.2940549850463867
Epoch: 6, Loss: 2.281979560852051
Epoch: 6, Loss: 2.2751331329345703
Epoch: 6, Accuracy: 0.126
Epoch: 7, Loss: 2.288757085800171
Epoch: 7, Loss: 2.253366231918335
Epoch: 7, Loss: 2.