<a href="https://colab.research.google.com/github/AIEnthusiasts/CS7641-Assignment-1-Summer-/blob/main/CS7641_Assignment_1_MNIST_NN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


# Preprocessing applied to every image: convert to a tensor, then normalize
# the single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST train and test splits, downloaded into ./data when not already there.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Mini-batches of 32; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=32,
    shuffle=False,
)


class Net(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 with ReLU in between.

    The final layer's output is returned as-is (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Three linear layers; sizes are fixed for 28x28 inputs and 10 outputs.
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten *x* to rows of 784 values and return a (rows, 10) tensor."""
        flat = x.view(-1, 28 * 28)
        hidden = self.fc1(flat).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

# Loss used for both training and evaluation.
criterion = nn.CrossEntropyLoss()

# Build the model, then let Adam optimize all of its parameters.
net = Net()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)


def _evaluate(net, testloader, criterion):
    """Return ``(mean_batch_loss, accuracy)`` of *net* over *testloader*.

    Puts *net* into eval mode and runs without gradient tracking. The model
    is left in eval mode; callers that keep training must call ``net.train()``.
    """
    net.eval()
    loss_sum = 0.0
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for test_images, test_labels in testloader:
            test_outputs = net(test_images)
            loss_sum += criterion(test_outputs, test_labels).item()
            _, test_predicted = torch.max(test_outputs, 1)
            test_total += test_labels.size(0)
            test_correct += (test_predicted == test_labels).sum().item()

    # The loss is the plain mean of the per-batch losses (one term per batch).
    return loss_sum / len(testloader), test_correct / test_total


def train_model(net, trainloader, testloader, criterion, optimizer, epochs=1):
    """Train *net* and record train/test metrics after every training batch.

    After each optimizer step the whole *testloader* is evaluated, so each
    returned history has one entry per training batch (across all epochs).
    The training loss and accuracy are running averages that restart at the
    beginning of each epoch.

    Args:
        net: Model to train; called as ``net(inputs)``.
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Sized iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of *net*.
        epochs: Number of passes over *trainloader*.

    Returns:
        Tuple ``(train_loss_history, train_acc_history, test_loss_history,
        test_acc_history)`` of lists of floats.
    """
    train_acc_history = []
    test_acc_history = []
    train_loss_history = []
    test_loss_history = []

    for epoch in range(epochs):
        running_loss = 0.0
        # Dedicated training counters: evaluation keeps its own, so the
        # running training accuracy is never reset or mixed with test counts.
        train_correct = 0
        train_total = 0

        for i, (inputs, labels) in enumerate(trainloader, 0):
            # Evaluation at the end of the previous batch left the model in
            # eval mode, so train mode has to be restored on every batch.
            net.train()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.detach(), 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

            train_loss = running_loss / (i + 1)
            train_acc = train_correct / train_total
            train_loss_history.append(train_loss)
            train_acc_history.append(train_acc)

            # Evaluate on test data after each batch
            test_loss, test_acc = _evaluate(net, testloader, criterion)
            test_loss_history.append(test_loss)
            test_acc_history.append(test_acc)

            print(f'Epoch {epoch + 1}, Batch {i + 1}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

    return train_loss_history, train_acc_history, test_loss_history, test_acc_history

# Train for a single epoch and keep the four per-batch metric histories.
(
    train_loss_history,
    train_acc_history,
    test_loss_history,
    test_acc_history,
) = train_model(
    net,
    trainloader,
    testloader,
    criterion,
    optimizer,
    epochs=1,
)



In [None]:
def plot_training_history(train_loss, train_acc, test_loss, test_acc):
    """Show loss and accuracy curves side by side in one 12x5 figure.

    The left panel plots *train_loss* (blue) and *test_loss* (red); the right
    panel plots *train_acc* and *test_acc* the same way. The x values run from
    1 to ``len(train_loss)``, one per history entry.

    NOTE(review): the x-axis is labelled 'Epochs' and the test series
    'Validation', whatever granularity or split the caller passes in.
    """
    steps = range(1, len(train_loss) + 1)

    # One row per subplot:
    # (train series, test series, train label, test label, title, y label).
    panels = (
        (train_loss, test_loss, 'Training loss', 'Validation loss',
         'Training and validation loss', 'Loss'),
        (train_acc, test_acc, 'Training accuracy', 'Validation accuracy',
         'Training and validation accuracy', 'Accuracy'),
    )

    plt.figure(figsize=(12, 5))

    for column, panel in enumerate(panels, start=1):
        train_series, test_series, train_label, test_label, title, y_label = panel
        plt.subplot(1, 2, column)
        plt.plot(steps, train_series, 'b', label=train_label)
        plt.plot(steps, test_series, 'r', label=test_label)
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.legend()

    plt.show()

# Plot the histories collected during training.
plot_training_history(
    train_loss=train_loss_history,
    train_acc=train_acc_history,
    test_loss=test_loss_history,
    test_acc=test_acc_history,
)
