In [24]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt


input_size=784 # 28*28 pixels of image flattened to 1D array of 784
classes=10 # 10 classes of digits
epochs=6
batch_size=100

training_dataset = datasets.MNIST(root="data",train=True,download=False,transform=ToTensor(),)
test_dataset = datasets.MNIST(root="data",train=False,download=False,transform=ToTensor(),)
train_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# neural network
class mnistNN(nn.Module):
    def __init__(self,input_size,classes):
        super(mnistNN,self).__init__()
        self.linear1=nn.Linear(input_size,128)
        self.relu1=nn.ReLU()
        self.linear2=nn.Linear(128,64)
        self.relu2=nn.ReLU()
        self.linear3=nn.Linear(64,classes)
        
    def forward(self,x):
        out=self.linear1(x)
        out=self.relu1(out)
        out=self.linear2(out)
        out=self.relu2(out)
        out=self.linear3(out)
        
        return out
    
model=mnistNN(input_size,classes)

# loss and optimizer
loss=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)

# training loop
for epoch in range(epochs):
    for batch_idx,(data,labels) in enumerate(train_loader):
        data=data.reshape(data.shape[0],-1)
        
        # forward
        scores=model(data)
        loss_=loss(scores,labels)
        
        # backward
        optimizer.zero_grad()
        loss_.backward()
        
        # gradient descent or adam step
        optimizer.step()
        if batch_idx%100==0:
            print(f"epoch {epoch+1} batch {batch_idx} loss {loss_}")

with torch.no_grad():
    n_correct=0
    n_samples=0
    for data,labels in test_loader:
        data=data.reshape(data.shape[0],-1)
        scores=model(data)

        _,predictions=torch.max(scores,1)
        n_samples+=labels.shape[0]
        n_correct+=(predictions==labels).sum().item()
    acc=100.0*n_correct/n_samples
    print(f"accuracy {acc}")
    

epoch 1 batch 0 loss 2.2921438217163086
epoch 1 batch 100 loss 0.18474571406841278
epoch 1 batch 200 loss 0.40350237488746643
epoch 1 batch 300 loss 0.19458512961864471
epoch 1 batch 400 loss 0.1777907758951187
epoch 1 batch 500 loss 0.3286444842815399
epoch 2 batch 0 loss 0.0732782632112503
epoch 2 batch 100 loss 0.11235550791025162
epoch 2 batch 200 loss 0.23901206254959106
epoch 2 batch 300 loss 0.1712895631790161
epoch 2 batch 400 loss 0.08825799077749252
epoch 2 batch 500 loss 0.1685595065355301
epoch 3 batch 0 loss 0.08111123740673065
epoch 3 batch 100 loss 0.13382571935653687
epoch 3 batch 200 loss 0.03789062798023224
epoch 3 batch 300 loss 0.04789008945226669
epoch 3 batch 400 loss 0.23688797652721405
epoch 3 batch 500 loss 0.034689635038375854
epoch 4 batch 0 loss 0.08817217499017715
epoch 4 batch 100 loss 0.1316395401954651
epoch 4 batch 200 loss 0.03181307017803192
epoch 4 batch 300 loss 0.054870251566171646
epoch 4 batch 400 loss 0.20736698806285858
epoch 4 batch 500 loss 0

In [40]:
import torch
import torchvision.models as models

# Assuming you have a model instance (replace this with your model)
model = models.resnet18()

# Save the entire model
torch.save(model, 'mnist_model.pth')


In [38]:
import numpy as np
import torch
from torchvision.io import read_image
from torchvision.transforms import ToTensor

image=read_image("image.jpg")
chnl, height, width = image.shape

image_seg=[] # list of 28x28 consequal images's tensors
for i in range(width//height):
    # split image into 28x28 chunks 
    img_tensor = image[:, :, i*height:(i+1)*height]
    image_seg.append(img_tensor)

# predicting value from these 28x28 chunks using model of mnist
prediction=0
for i in image_seg:
    i=i.reshape(1,-1) # datatype of i is torch.uint8
    i=i.type(torch.FloatTensor) # datatype of i is torch.float32
    pred=model(i)
    pred_val=torch.argmax(pred)
    prediction=prediction*10+pred_val.item()

print(prediction)



726766295


Corrupt JPEG data: premature end of data segment
