## Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from tqdm import tqdm

## Load Dataset

In [2]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.Resize((70,70)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
dataset = torchvision.datasets.ImageFolder(root="../input/persian-mnist/MNIST_persian",transform=transform)

train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,32,(3,3),(1,1),(1,1))
        self.conv2 = nn.Conv2d(32,64,(3,3),(1,1),(1,1))
        self.conv3 = nn.Conv2d(64,128,(3,3),(1,1),(1,1))
        self.conv4 = nn.Conv2d(128,256,(3,3),(1,1),(1,1))
        
        self.fully_connect1 = nn.Linear(256*4*4,256)
        self.fully_connect2 = nn.Linear(256,128)
        self.fully_connect3 = nn.Linear(128,10)
        
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x,kernel_size=(2,2))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x,kernel_size=(2,2))
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x,kernel_size=(2,2))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x,kernel_size=(2,2))
        x = torch.flatten(x,start_dim=1)
        x = F.relu(self.fully_connect1(x))
        x = torch.dropout(x,0.2,train=True)
        x = F.relu(self.fully_connect2(x))
        x = torch.dropout(x,0.5,train=True)
        x = F.softmax(self.fully_connect3(x),dim=1)
        
        return x

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
model.train = True

In [10]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
loss_function = nn.CrossEntropyLoss()

In [11]:
def accuracy(preds,labels):
    _,pred_max = torch.max(preds,1)
    acc = torch.sum(pred_max==labels,dtype=torch.float64) / len(preds)
    
    return acc

In [12]:
for epoch in range(30):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0
    for image, label in tqdm(train):
        image.to(device)
        label.to(device)
        optimizer.zero_grad()
        
        predict = model(image)
        loss = loss_function(predict,label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss
        train_acc += accuracy(predict,label)
    
    for image, label in test:
        image.to(device)
        label.to(device)
        predict = model(image)
        loss = loss_function(predict,label)
        
        test_loss += loss
        test_acc += accuracy(predict,label)
    
    total_train_loss = train_loss / len(train)
    total_train_acc = train_acc / len(train)
    total_test_loss = test_loss / len(test)
    total_test_acc = test_acc / len(test)
    
    print(f"Epochs: {epoch+1}, Accuracy: {total_train_acc}, Loss: {total_train_loss}")
    print(f"Epochs: {epoch+1}, Val_Accuracy: {total_test_acc}, Val_Loss: {total_test_loss}")

100%|██████████| 38/38 [00:13<00:00,  2.91it/s]


Epochs: 1, Accuracy: 0.09950657894736842, Loss: 2.3028862476348877
Epochs: 1, Val_Accuracy: 0.1015625, Val_Loss: 2.303030014038086


100%|██████████| 38/38 [00:08<00:00,  4.27it/s]


Epochs: 2, Accuracy: 0.10115131578947369, Loss: 2.302900791168213
Epochs: 2, Val_Accuracy: 0.1015625, Val_Loss: 2.3029346466064453


100%|██████████| 38/38 [00:07<00:00,  4.91it/s]


Epochs: 3, Accuracy: 0.10608552631578948, Loss: 2.3024017810821533
Epochs: 3, Val_Accuracy: 0.17708333333333334, Val_Loss: 2.3000500202178955


100%|██████████| 38/38 [00:08<00:00,  4.75it/s]


Epochs: 4, Accuracy: 0.22532894736842105, Loss: 2.2115023136138916
Epochs: 4, Val_Accuracy: 0.3619791666666667, Val_Loss: 2.075199604034424


100%|██████████| 38/38 [00:08<00:00,  4.64it/s]


Epochs: 5, Accuracy: 0.4152960526315789, Loss: 2.048793315887451
Epochs: 5, Val_Accuracy: 0.5234375, Val_Loss: 1.9349243640899658


100%|██████████| 38/38 [00:07<00:00,  4.94it/s]


Epochs: 6, Accuracy: 0.5493421052631579, Loss: 1.9122942686080933
Epochs: 6, Val_Accuracy: 0.5911458333333334, Val_Loss: 1.8665812015533447


100%|██████████| 38/38 [00:07<00:00,  4.95it/s]


Epochs: 7, Accuracy: 0.5731907894736842, Loss: 1.8822224140167236
Epochs: 7, Val_Accuracy: 0.5651041666666666, Val_Loss: 1.9002069234848022


100%|██████████| 38/38 [00:07<00:00,  4.86it/s]


Epochs: 8, Accuracy: 0.6208881578947368, Loss: 1.840102195739746
Epochs: 8, Val_Accuracy: 0.6432291666666666, Val_Loss: 1.8200523853302002


100%|██████████| 38/38 [00:08<00:00,  4.55it/s]


Epochs: 9, Accuracy: 0.6875, Loss: 1.7725944519042969
Epochs: 9, Val_Accuracy: 0.7708333333333334, Val_Loss: 1.6999225616455078


100%|██████████| 38/38 [00:07<00:00,  4.82it/s]


Epochs: 10, Accuracy: 0.7335526315789473, Loss: 1.7285302877426147
Epochs: 10, Val_Accuracy: 0.75, Val_Loss: 1.711547613143921


100%|██████████| 38/38 [00:07<00:00,  4.84it/s]


Epochs: 11, Accuracy: 0.7483552631578947, Loss: 1.7125781774520874
Epochs: 11, Val_Accuracy: 0.8255208333333334, Val_Loss: 1.6393795013427734


100%|██████████| 38/38 [00:07<00:00,  4.79it/s]


Epochs: 12, Accuracy: 0.7960526315789473, Loss: 1.671675682067871
Epochs: 12, Val_Accuracy: 0.8619791666666666, Val_Loss: 1.610640287399292


100%|██████████| 38/38 [00:08<00:00,  4.56it/s]


Epochs: 13, Accuracy: 0.8092105263157895, Loss: 1.6512322425842285
Epochs: 13, Val_Accuracy: 0.8776041666666666, Val_Loss: 1.5967243909835815


100%|██████████| 38/38 [00:07<00:00,  4.84it/s]


Epochs: 14, Accuracy: 0.8396381578947368, Loss: 1.6253942251205444
Epochs: 14, Val_Accuracy: 0.8411458333333334, Val_Loss: 1.6282970905303955


100%|██████████| 38/38 [00:07<00:00,  4.91it/s]


Epochs: 15, Accuracy: 0.8544407894736842, Loss: 1.6110210418701172
Epochs: 15, Val_Accuracy: 0.8828125, Val_Loss: 1.5685027837753296


100%|██████████| 38/38 [00:07<00:00,  4.80it/s]


Epochs: 16, Accuracy: 0.8914473684210527, Loss: 1.5790038108825684
Epochs: 16, Val_Accuracy: 0.9661458333333334, Val_Loss: 1.5049437284469604


100%|██████████| 38/38 [00:08<00:00,  4.40it/s]


Epochs: 17, Accuracy: 0.9029605263157895, Loss: 1.559190273284912
Epochs: 17, Val_Accuracy: 0.9166666666666666, Val_Loss: 1.5474236011505127


100%|██████████| 38/38 [00:08<00:00,  4.68it/s]


Epochs: 18, Accuracy: 0.912828947368421, Loss: 1.5523298978805542
Epochs: 18, Val_Accuracy: 0.9270833333333334, Val_Loss: 1.5277385711669922


100%|██████████| 38/38 [00:08<00:00,  4.38it/s]


Epochs: 19, Accuracy: 0.9235197368421053, Loss: 1.53855299949646
Epochs: 19, Val_Accuracy: 0.9036458333333334, Val_Loss: 1.56006920337677


100%|██████████| 38/38 [00:08<00:00,  4.31it/s]


Epochs: 20, Accuracy: 0.9202302631578947, Loss: 1.5403207540512085
Epochs: 20, Val_Accuracy: 0.9270833333333334, Val_Loss: 1.5297139883041382


100%|██████████| 38/38 [00:08<00:00,  4.37it/s]


Epochs: 21, Accuracy: 0.9210526315789473, Loss: 1.5420230627059937
Epochs: 21, Val_Accuracy: 0.96875, Val_Loss: 1.4934837818145752


100%|██████████| 38/38 [00:08<00:00,  4.43it/s]


Epochs: 22, Accuracy: 0.9177631578947368, Loss: 1.5455228090286255
Epochs: 22, Val_Accuracy: 0.9348958333333334, Val_Loss: 1.5304040908813477


100%|██████████| 38/38 [00:08<00:00,  4.45it/s]


Epochs: 23, Accuracy: 0.9276315789473685, Loss: 1.534140944480896
Epochs: 23, Val_Accuracy: 0.9036458333333334, Val_Loss: 1.5513995885849


100%|██████████| 38/38 [00:09<00:00,  4.06it/s]


Epochs: 24, Accuracy: 0.9226973684210527, Loss: 1.537224531173706
Epochs: 24, Val_Accuracy: 0.9505208333333334, Val_Loss: 1.5141313076019287


100%|██████████| 38/38 [00:08<00:00,  4.29it/s]


Epochs: 25, Accuracy: 0.921875, Loss: 1.5410490036010742
Epochs: 25, Val_Accuracy: 0.9739583333333334, Val_Loss: 1.4864944219589233


100%|██████████| 38/38 [00:09<00:00,  4.21it/s]


Epochs: 26, Accuracy: 0.9564144736842105, Loss: 1.5060187578201294
Epochs: 26, Val_Accuracy: 0.9583333333333334, Val_Loss: 1.5059019327163696


100%|██████████| 38/38 [00:09<00:00,  4.18it/s]


Epochs: 27, Accuracy: 0.9333881578947368, Loss: 1.5253266096115112
Epochs: 27, Val_Accuracy: 0.9192708333333334, Val_Loss: 1.5350865125656128


100%|██████████| 38/38 [00:09<00:00,  4.10it/s]


Epochs: 28, Accuracy: 0.9366776315789473, Loss: 1.5246905088424683
Epochs: 28, Val_Accuracy: 0.984375, Val_Loss: 1.4792197942733765


100%|██████████| 38/38 [00:08<00:00,  4.41it/s]


Epochs: 29, Accuracy: 0.9564144736842105, Loss: 1.5050941705703735
Epochs: 29, Val_Accuracy: 0.96875, Val_Loss: 1.4970831871032715


100%|██████████| 38/38 [00:08<00:00,  4.23it/s]


Epochs: 30, Accuracy: 0.953125, Loss: 1.5100843906402588
Epochs: 30, Val_Accuracy: 0.9609375, Val_Loss: 1.5009663105010986


In [13]:
torch.save(model.state_dict(),"weights.pth")