In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Mount Google Drive at /content/drive; the dataset cell further down reads
# its images from a folder under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def calc_acc(preds,labels):
  """Fraction of rows in `preds` whose highest-scoring column equals the label.

  Args:
    preds: 2-D tensor of per-class scores, one row per sample.
    labels: 1-D tensor of class indices, one per row of `preds`.

  Returns:
    A 0-dim float64 tensor: (number of matching rows) / len(preds).
  """
  predicted_classes=preds.max(dim=1).indices
  n_correct=(predicted_classes==labels).sum(dtype=torch.float64)
  return n_correct/preds.shape[0]

In [None]:

class Model(nn.Module):
  """Five-layer CNN followed by a three-layer MLP head producing 10 class scores.

  The first linear layer takes 16 features, so the convolutional stack has to
  shrink the image to a 16x1x1 map. A 3x30x30 input does exactly that:
  30 -> 28 (conv1, no padding) -> 14 -> 7 -> 3 -> 1 through the four 2x2 poolings.
  """
  def __init__(self):
    super().__init__()
    # conv1 is unpadded (trims one pixel per side); conv2..conv5 use padding 1
    # and therefore keep the spatial size.
    self.conv1=nn.Conv2d(3,512,(3,3),(1,1),(0,0))
    self.conv2=nn.Conv2d(512,128,(3,3),(1,1),(1,1))
    self.conv3=nn.Conv2d(128,64,(3,3),(1,1),(1,1))
    self.conv4=nn.Conv2d(64,32,(3,3),(1,1),(1,1))
    self.conv5=nn.Conv2d(32,16,(3,3),(1,1),(1,1))

    self.fc1=nn.Linear(16*1*1,128)
    self.fc2=nn.Linear(128,256)
    self.fc3=nn.Linear(256,10)

  def forward(self,x):
    """Compute class logits for a batch of images.

    Args:
      x: float tensor of shape (batch, 3, H, W); H and W must reduce to 1x1
        after the conv/pool stack (e.g. 30x30).

    Returns:
      Tensor of shape (batch, 10) holding raw, unnormalised logits. Apply
      torch.softmax(out, dim=1) to obtain probabilities; argmax is the same
      either way.
    """
    x=F.relu(self.conv1(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv2(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv3(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv4(x))
    x=F.max_pool2d(x,kernel_size=(2,2))

    x=F.relu(self.conv5(x))

    x=torch.flatten(x,start_dim=1)

    x=F.relu(self.fc1(x))
    x=F.relu(self.fc2(x))
    # No softmax here: nn.CrossEntropyLoss applies log-softmax itself, and
    # feeding it probabilities squashes the gradients and bounds the loss
    # away from zero.
    return self.fc3(x)

In [None]:
# Seed torch's global generator before the model is built so that weight
# initialisation (and any later draws from that generator) start from the
# same state on every run. Exact repeatability on GPU still depends on the
# backend's kernels.
torch.manual_seed(0)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=Model().to(device)

In [None]:
# Training configuration: mini-batch size, number of passes over the
# dataset, and the optimizer's learning rate.
batch_size,epochs,lr=16,40,1e-4


In [None]:
# Per-image preprocessing: random rotation of up to 10 degrees, resize to
# 30x30, conversion to a tensor, then per-channel normalisation.
transform=transforms.Compose([
  transforms.RandomRotation(degrees=10),
  transforms.Resize(size=(30,30)),
  transforms.ToTensor(),
  transforms.Normalize(
    mean=(0.485,0.456,0.406),
    std=(0.229,0.224,0.225),
  ),
])

# ImageFolder derives each sample's label from the sub-directory it sits in.
dataset=torchvision.datasets.ImageFolder(
  "/content/drive/MyDrive/persian-mnist",
  transform=transform,
)

# Reshuffled mini-batches of batch_size samples for training.
train_data_loader=torch.utils.data.DataLoader(
  dataset=dataset,
  batch_size=batch_size,
  shuffle=True,
)

In [None]:
# Training objective and optimizer: cross-entropy over the class outputs,
# minimised with Adam over every model parameter at learning rate lr.
loss_function=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(params=model.parameters(),lr=lr)

In [None]:
# Training loop: one optimizer step per mini-batch, then the mean per-batch
# loss and accuracy for the epoch are printed.
model.train()
for epoch in range(epochs):
  # Running totals are plain Python floats. Adding the loss tensor itself
  # would chain every batch's autograd history onto the total and keep it
  # alive until the end of the epoch.
  train_loss=0.0
  train_acc=0.0
  for images,labels in tqdm(train_data_loader):
    images,labels=images.to(device),labels.to(device)
    optimizer.zero_grad()

    preds=model(images)
    loss=loss_function(preds,labels)
    loss.backward()
    optimizer.step()
    # .item() extracts the scalar value, detached from the graph.
    train_loss+=loss.item()
    train_acc+=calc_acc(preds,labels).item()

  # Averages are taken over the number of batches, not the number of samples.
  total_loss=train_loss/len(train_data_loader)
  total_acc=train_acc/len(train_data_loader)
  print(f"epochs: {epoch} , loss: {total_loss}, acc: {total_acc}")

100%|██████████| 75/75 [00:03<00:00, 20.56it/s]


epochs: 0 , loss: 2.3027124404907227, acc: 0.08916666666666667


100%|██████████| 75/75 [00:03<00:00, 21.28it/s]


epochs: 1 , loss: 2.3025569915771484, acc: 0.10500000000000001


100%|██████████| 75/75 [00:03<00:00, 21.41it/s]


epochs: 2 , loss: 2.299149513244629, acc: 0.24666666666666667


100%|██████████| 75/75 [00:03<00:00, 21.28it/s]


epochs: 3 , loss: 2.238128423690796, acc: 0.18916666666666668


100%|██████████| 75/75 [00:03<00:00, 21.42it/s]


epochs: 4 , loss: 2.099407911300659, acc: 0.39


100%|██████████| 75/75 [00:03<00:00, 21.45it/s]


epochs: 5 , loss: 1.9515552520751953, acc: 0.5616666666666666


100%|██████████| 75/75 [00:03<00:00, 21.35it/s]


epochs: 6 , loss: 1.850368857383728, acc: 0.6466666666666667


100%|██████████| 75/75 [00:03<00:00, 21.63it/s]


epochs: 7 , loss: 1.8109608888626099, acc: 0.67


100%|██████████| 75/75 [00:03<00:00, 22.06it/s]


epochs: 8 , loss: 1.7867223024368286, acc: 0.6950000000000001


100%|██████████| 75/75 [00:03<00:00, 21.33it/s]


epochs: 9 , loss: 1.7752232551574707, acc: 0.7025


100%|██████████| 75/75 [00:03<00:00, 21.23it/s]


epochs: 10 , loss: 1.7669274806976318, acc: 0.7116666666666667


100%|██████████| 75/75 [00:03<00:00, 21.59it/s]


epochs: 11 , loss: 1.7734137773513794, acc: 0.7025


100%|██████████| 75/75 [00:03<00:00, 21.48it/s]


epochs: 12 , loss: 1.754185438156128, acc: 0.7166666666666667


100%|██████████| 75/75 [00:03<00:00, 21.33it/s]


epochs: 13 , loss: 1.7489073276519775, acc: 0.7200000000000001


100%|██████████| 75/75 [00:03<00:00, 20.91it/s]


epochs: 14 , loss: 1.7420146465301514, acc: 0.7300000000000001


100%|██████████| 75/75 [00:03<00:00, 20.89it/s]


epochs: 15 , loss: 1.739540457725525, acc: 0.7250000000000001


100%|██████████| 75/75 [00:03<00:00, 21.53it/s]


epochs: 16 , loss: 1.7267242670059204, acc: 0.7400000000000001


100%|██████████| 75/75 [00:03<00:00, 21.34it/s]


epochs: 17 , loss: 1.7301024198532104, acc: 0.7400000000000001


100%|██████████| 75/75 [00:03<00:00, 20.71it/s]


epochs: 18 , loss: 1.7244983911514282, acc: 0.7408333333333333


100%|██████████| 75/75 [00:03<00:00, 20.50it/s]


epochs: 19 , loss: 1.7127039432525635, acc: 0.7566666666666667


100%|██████████| 75/75 [00:03<00:00, 20.37it/s]


epochs: 20 , loss: 1.6941473484039307, acc: 0.7708333333333334


100%|██████████| 75/75 [00:03<00:00, 20.03it/s]


epochs: 21 , loss: 1.6938954591751099, acc: 0.7741666666666667


100%|██████████| 75/75 [00:03<00:00, 20.11it/s]


epochs: 22 , loss: 1.6927769184112549, acc: 0.7766666666666667


100%|██████████| 75/75 [00:03<00:00, 21.56it/s]


epochs: 23 , loss: 1.677361011505127, acc: 0.7908333333333334


100%|██████████| 75/75 [00:03<00:00, 21.24it/s]


epochs: 24 , loss: 1.6704503297805786, acc: 0.7975000000000001


100%|██████████| 75/75 [00:03<00:00, 21.52it/s]


epochs: 25 , loss: 1.6753350496292114, acc: 0.7866666666666667


100%|██████████| 75/75 [00:03<00:00, 21.40it/s]


epochs: 26 , loss: 1.6611829996109009, acc: 0.8091666666666667


100%|██████████| 75/75 [00:03<00:00, 19.34it/s]


epochs: 27 , loss: 1.65818452835083, acc: 0.81


100%|██████████| 75/75 [00:03<00:00, 20.33it/s]


epochs: 28 , loss: 1.6571049690246582, acc: 0.8091666666666667


100%|██████████| 75/75 [00:03<00:00, 20.87it/s]


epochs: 29 , loss: 1.6384624242782593, acc: 0.8333333333333334


100%|██████████| 75/75 [00:03<00:00, 20.71it/s]


epochs: 30 , loss: 1.6246086359024048, acc: 0.8466666666666667


100%|██████████| 75/75 [00:03<00:00, 20.47it/s]


epochs: 31 , loss: 1.6148911714553833, acc: 0.8691666666666668


100%|██████████| 75/75 [00:03<00:00, 20.61it/s]


epochs: 32 , loss: 1.5964601039886475, acc: 0.8791666666666668


100%|██████████| 75/75 [00:03<00:00, 20.53it/s]


epochs: 33 , loss: 1.5952109098434448, acc: 0.8825000000000001


100%|██████████| 75/75 [00:03<00:00, 20.93it/s]


epochs: 34 , loss: 1.5865421295166016, acc: 0.8841666666666668


100%|██████████| 75/75 [00:03<00:00, 20.96it/s]


epochs: 35 , loss: 1.581704020500183, acc: 0.8925000000000001


100%|██████████| 75/75 [00:03<00:00, 20.80it/s]


epochs: 36 , loss: 1.5878342390060425, acc: 0.885


100%|██████████| 75/75 [00:03<00:00, 20.78it/s]


epochs: 37 , loss: 1.576040267944336, acc: 0.8958333333333334


100%|██████████| 75/75 [00:03<00:00, 20.54it/s]


epochs: 38 , loss: 1.55534827709198, acc: 0.9208333333333334


100%|██████████| 75/75 [00:03<00:00, 20.58it/s]


epochs: 39 , loss: 1.5610859394073486, acc: 0.915


In [None]:
# Persist only the learned parameters (the state dict), written to
# persian.pth relative to the current working directory.
torch.save(obj=model.state_dict(),f="persian.pth")