In [None]:
# Importing dependencies
import torch
from PIL import Image
from torch import nn,save,load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Loading Data
# Only ToTensor is applied: PIL images become float tensors scaled to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, downloaded into ./data on first use.
train_dataset = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 samples for training.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
# Define the image classifier model
class ImageClassifier(nn.Module):
    """Three-layer CNN that maps single-channel 28x28 images to 10 class logits.

    Each 3x3 convolution uses no padding, so the spatial size shrinks by 2 per
    layer (28 -> 26 -> 24 -> 22). That is why the final linear layer takes
    64 * 22 * 22 features.
    """

    def __init__(self):
        super().__init__()
        # Channel progression through the conv stack: in -> 32 -> 64 -> 64.
        channels = (1, 32, 64, 64)
        stack = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            stack += [nn.Conv2d(in_ch, out_ch, kernel_size=3), nn.ReLU()]
        self.conv_layers = nn.Sequential(*stack)
        self.fc_layers = nn.Sequential(nn.Flatten(), nn.Linear(64 * 22 * 22, 10))

    def forward(self, x):
        """Return unnormalized class scores (logits) for the batch ``x``."""
        return self.fc_layers(self.conv_layers(x))
# Create an instance of the image classifier model
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
classifier = ImageClassifier()
classifier = classifier.to(device)
# Loss and optimizer: cross-entropy over the logits, Adam with lr=1e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=classifier.parameters(), lr=0.001)
# Train the model
num_epochs = 10
classifier.train()  # put any mode-dependent layers into training behaviour
for epoch in range(num_epochs):
    running_loss = 0.0
    num_samples = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()  # Reset gradients
        outputs = classifier(images)  # Forward pass
        loss = loss_fn(outputs, labels)  # Compute loss (mean over the batch)
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
        # loss is a per-batch mean; weight it by the batch size so the
        # reported figure is the mean over every sample in the epoch rather
        # than just the (noisy) loss of the final mini-batch.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # max(..., 1) guards against an empty loader.
    epoch_loss = running_loss / max(num_samples, 1)
    print(f"Epoch:{epoch} loss is {epoch_loss}")
# Save the trained model
model_path = 'model_state.pt'
torch.save(classifier.state_dict(), model_path)
# Load the saved model. map_location remaps the stored tensors onto the
# current device, so a checkpoint written on a GPU can still be restored on a
# CPU-only runtime (and vice versa).
# NOTE(review): torch.load unpickles the file. Only load checkpoints you
# trust, or pass weights_only=True if the installed PyTorch supports it.
classifier.load_state_dict(load(model_path, map_location=device))

# Perform inference on an image
# The network was trained on single-channel 28x28 MNIST tensors. An arbitrary
# image must therefore be converted to grayscale and resized to 28x28 first.
# Otherwise the first Conv2d (1 input channel) or the final Linear layer
# (64*22*22 features) fails with a shape mismatch.
# NOTE(review): MNIST digits are light strokes on a dark background; photos of
# dark ink on light paper may need inverting. Confirm for your inputs.
img_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
])
with Image.open('/content/image.jpg') as img:  # closes the file handle when done
    img_tensor = img_transform(img.convert("L")).unsqueeze(0).to(device)

classifier.eval()  # switch any mode-dependent layers to inference behaviour
with torch.no_grad():  # no autograd graph is needed for prediction
    output = classifier(img_tensor)
# output is (1, 10) logits; take the best class for the single image.
predicted_label = torch.argmax(output, dim=1).item()
print(f"Predicted label: {predicted_label}")

Epoch:0 loss is 0.014564729295670986
Epoch:1 loss is 0.05569921433925629
Epoch:2 loss is 0.0010305619798600674
Epoch:3 loss is 0.06546627730131149
Epoch:4 loss is 0.03309878334403038
Epoch:5 loss is 2.277787461935077e-05
Epoch:6 loss is 7.010154513409361e-05
Epoch:7 loss is 1.1808187991846353e-05
Epoch:8 loss is 0.00016703254368621856
Epoch:9 loss is 5.0206301239086315e-05
Predicted label: 9
