In [4]:
import torch 
import os 
import numpy as np
import torch.nn as nn 
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets,transforms
import mlflow 
import mlflow.pytorch



In [5]:
class Config:
    EPOCHS=10
    BATCH_SIZE=32 
    LR=0.01 
    DEVICE="cuda" if torch.cuda.is_available() else "cpu"
    GAMMA=0.7 
    SEED=42
    LOG_INTERVAL=10
    TEST_BATCH_SIZE=1000
    DRY_RUN=True


In [6]:
config=Config()

In [7]:
from torch import nn
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet,self).__init__()
        self.conv1=nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3,stride=1)
        self.conv2=nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,stride=1)
        self.dropout1=nn.Dropout(0.25)
        self.dropout2=nn.Dropout(0.5)
        self.fc1=nn.Linear(in_features=9216,out_features=128)
        self.fc2=nn.Linear(in_features=128,out_features=10)

    def forward(self,x):
        x=self.conv1(x)
        x=F.relu(x)
        x=self.conv2(x)
        x=F.relu(x)
        x=F.max_pool2d(x,2)
        x=self.dropout1(x)
        x=torch.flatten(x,1)
        x=self.fc1(x)
        x=F.relu(x)
        x=self.dropout2(x)
        x=self.fc2(x)

        output=F.log_softmax(x,dim=1)
        return output

In [27]:
def train(config,model,device,train_dataloader,optimizer,epoch):
    model.train()
    for batch_idx,(data,target) in enumerate(train_dataloader):
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        pred=model(data)
        loss=F.cross_entropy(pred,target)
        loss.backward()
        optimizer.step()
        if batch_idx % config.LOG_INTERVAL==0:
            print(f"train epoch: {epoch} [{batch_idx * len(data)}/{len(train_dataloader)} ({100.0*batch_idx/len(train_dataloader):.0f}]\t Loss: {loss.item():.6f}")

            if config.DRY_RUN:
                break

In [9]:
def test(config,model,device,test_dataloader):
    pass

In [10]:
torch.manual_seed(config.SEED)

<torch._C.Generator at 0x7faff5508df0>

In [11]:
train_kwargs={"batch_size":config.BATCH_SIZE}
test_kwargs={"batch_size":config.TEST_BATCH_SIZE}

In [12]:
if config.DEVICE=="cuda":
    cuda_kwargs={"num_workers":1,"pin_memory":True,"suffle":True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

In [13]:
transforms=transforms.Compose([
    transforms.ToTensor() 
])

In [23]:
train=datasets.MNIST("../data",train=True,download=True,transform=transforms)
test=datasets.MNIST("../data",transform=transforms)

In [24]:
train_loader=torch.utils.data.DataLoader(train,**train_kwargs)
test_loader=torch.utils.data.DataLoader(test,**test_kwargs)

In [25]:
image_batch,label_batch=next(iter(train_loader))
image_batch.shape,label_batch.shape

(torch.Size([32, 1, 28, 28]), torch.Size([32]))

In [26]:
model=ConvNet().to(config.DEVICE)
optimizer=torch.optim.Adam(model.parameters(),lr=config.LR)
schedular=StepLR(optimizer,step_size=1,gamma=config.GAMMA)

In [28]:
#training loop
for epoch in range(1,config.EPOCHS+1):
    train(config,model,config.DEVICE,train_loader,optimizer,epoch)
    schedular.step()

train epoch: 1 [0/1875 (0]	 Loss: 2.309456
train epoch: 2 [0/1875 (0]	 Loss: 9.522303
train epoch: 3 [0/1875 (0]	 Loss: 5.975313
train epoch: 4 [0/1875 (0]	 Loss: 2.737253
train epoch: 5 [0/1875 (0]	 Loss: 2.340699
train epoch: 6 [0/1875 (0]	 Loss: 2.279648
train epoch: 7 [0/1875 (0]	 Loss: 2.300728
train epoch: 8 [0/1875 (0]	 Loss: 2.298287
train epoch: 9 [0/1875 (0]	 Loss: 2.303431
train epoch: 10 [0/1875 (0]	 Loss: 2.301057


In [20]:
len(train_loader.dataset)

60000

In [34]:
with mlflow.start_run() as run:
    mlflow.pytorch.log_model(model,"model")
    model_path=mlflow.get_artifact_uri("model")
    loaded_torch_model=mlflow.pytorch.load_model(model_path)
    model.eval()
    with torch.inference_mode():
        test_datapoints,test_target=next(iter(test_loader))
        pred=model(test_datapoints[0].unsqueeze(0).to(config.DEVICE))
        actual=test_target[0].item()
        predicted=torch.argmax(pred).item()
        print(f"actual:{actual},predicted:{predicted}")

actual:5,predicted:5
