In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

In [3]:
# Per-channel statistics of ImageNet, the dataset the pretrained ResNet weights
# were fitted on. The backbone expects inputs standardized with these values;
# using a generic 0.5/0.5 normalization shifts the input distribution away from
# what the pretrained filters were trained to see.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every MNIST image before it reaches the network.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # upscale to the 224x224 input size used for ResNet
    transforms.Grayscale(num_output_channels=3),  # replicate the single channel into 3 channels
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),  # match the pretrained weights' input statistics
])

In [4]:
# Training split of MNIST, stored under ./data (downloaded if not already
# present); each image goes through `transform` when it is read.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 images, loaded by 4 worker processes.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [5]:
# Load ResNet-18 with ImageNet-pretrained weights.
# torchvision >= 0.13 selects checkpoints through the `weights=` enum and warns
# on the legacy `pretrained=True` flag; older releases only know the flag.
# Both branches load the same ImageNet-1k checkpoint.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)



In [6]:
# Freeze every parameter currently in the model so that backpropagation does
# not compute gradients for (or update) the pretrained weights.
for param in model.parameters():
    param.requires_grad_(False)

In [7]:
# Width of the feature vector that feeds the model's existing classifier layer;
# the replacement head must accept the same number of inputs.
num_ftrs = model.fc.in_features

# MNIST has 10 classes (digits 0-9).
num_classes = 10

# Replacement classification head: project the features down to 256 units,
# apply ReLU and dropout (p=0.4), then map to one logit per class.
# Freshly constructed layers have requires_grad=True, so this head is trainable.
head_layers = [
    nn.Linear(num_ftrs, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, num_classes),
]
model.fc = nn.Sequential(*head_layers)


In [8]:
# Adam is given only the parameters of `model.fc`, so an optimizer step never
# touches any other part of the network.
optimizer = torch.optim.Adam(params=model.fc.parameters(), lr=1e-3)

# Cross-entropy loss between the model's class scores and the integer labels.
criterion = nn.CrossEntropyLoss()

In [None]:
num_epochs = 5

# Feed batches to whichever device the model currently lives on. This does not
# move the model, so a CPU-only run behaves as before, while a model that was
# placed on an accelerator no longer fails on CPU input tensors.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # Re-assert training mode each epoch so Dropout in the head stays active
    # even if an evaluation pass switched the model to eval() before a re-run.
    model.train()

    running_loss = 0.0  # sum of per-sample losses seen this epoch
    seen_samples = 0    # number of samples seen this epoch

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so
        # a smaller final batch does not skew the epoch average.
        batch_n = labels.size(0)
        running_loss += loss.item() * batch_n
        seen_samples += batch_n

    # Report the mean loss over the whole epoch rather than the last batch's
    # loss, which is noisy; an empty loader reports nan instead of raising.
    epoch_loss = running_loss / seen_samples if seen_samples else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")