<a href="https://colab.research.google.com/github/Parv-Agarwal/Internship-project/blob/main/Pre_Trained_Feature_Extractor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [2]:
# Seed torch's global RNG up front so that weight initialisation and the
# DataLoader's shuffling are reproducible on Restart & Run All.
# torch.manual_seed seeds the CPU and all CUDA devices; cuDNN kernels may still
# introduce small run-to-run differences on GPU.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Preprocessing shared by both MNIST splits:
#   1. resize every image to 32x32 pixels,
#   2. replicate the grey channel three times, because the feature
#      extractor's first convolution takes 3 input channels,
#   3. convert the PIL image to a float tensor (ToTensor scales to [0, 1]).
preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)


In [4]:
# Both splits share the same root directory, download flag and preprocessing;
# only the `train` flag differs.
mnist_options = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_options)
test_dataset = datasets.MNIST(train=False, **mnist_options)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:04<00:00, 1.99MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 487kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:01<00:00, 1.51MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 5.27MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
BATCH_SIZE = 64

# Shuffle only the training split; keep the test split in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Feature extractor class
class FeatureExtractor(nn.Module):
    """Convolutional encoder that turns a 3-channel image batch into flat 784-d features.

    Pipeline:
      * two stages of Conv2d(5x5) -> ReLU -> MaxPool(2), widening 3 -> 32 -> 48 channels;
      * a 1x1 convolution that collapses the 48 maps into a single channel;
      * adaptive average pooling to a 28x28 map, which is then flattened.

    The attribute names (``conv``, ``conv_to_single_channel``, ``adaptive_pool``)
    and the layer order inside ``conv`` define the ``state_dict`` keys of the
    saved weight file, so they must stay as they are.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for in_channels, out_channels in ((3, 32), (32, 48)):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 5),
                nn.ReLU(True),
                nn.MaxPool2d(2, 2),
            ])
        self.conv = nn.Sequential(*layers)
        # Mix the 48 feature maps down to one channel.
        self.conv_to_single_channel = nn.Conv2d(48, 1, kernel_size=1)
        # For a 32x32 input the map reaching this layer is only 5x5
        # (32 -> 28 -> 14 -> 10 -> 5), so this "pool" enlarges it to 28x28.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((28, 28))

    def forward(self, x):
        """Return a (batch_size, 28 * 28) feature tensor for the image batch ``x``."""
        single_channel = self.conv_to_single_channel(self.conv(x))  # (batch_size, 1, H, W)
        pooled = self.adaptive_pool(single_channel)                 # (batch_size, 1, 28, 28)
        return pooled.view(pooled.shape[0], -1)


In [7]:
# Classifier class
class Classifier(nn.Module):
    """MLP head mapping 784-d feature vectors to log-probabilities over 10 classes.

    Widths are 784 -> 256 -> 128 -> 64 -> 10, with ReLU after every hidden
    layer and LogSoftmax over dim 1 at the end (to be paired with ``nn.NLLLoss``).
    The Sequential indices (Linear at 0, 2, 4, 6) are part of the ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 256, 128, 64, 10)
        transitions = list(zip(widths[:-1], widths[1:]))
        layers = []
        for position, (fan_in, fan_out) in enumerate(transitions):
            layers.append(nn.Linear(fan_in, fan_out))
            is_output_layer = position == len(transitions) - 1
            layers.append(nn.LogSoftmax(dim=1) if is_output_layer else nn.ReLU(True))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return (batch_size, 10) log-probabilities for the feature batch ``x``."""
        return self.fc(x)

In [9]:
# Instantiate both halves of the model on the selected device
# (extractor first, then classifier).
feature_extractor = FeatureExtractor().to(device)
classifier = Classifier().to(device)

# The classifier already ends in LogSoftmax, so NLLLoss is the matching loss.
criterion = nn.NLLLoss()
# A single Adam optimiser updates the extractor and the classifier jointly.
trainable_params = [*feature_extractor.parameters(), *classifier.parameters()]
optimizer = optim.Adam(trainable_params, lr=1e-3)

# Settings read by the training cell below.
num_epochs = 10
total_step = len(train_loader)  # number of batches per epoch

In [10]:
LOG_EVERY = 100  # print the current batch loss every LOG_EVERY steps
WEIGHTS_PATH = 'pre_trained_feature_extractor_weights.pth'

for epoch in range(num_epochs):
    # Put both modules into training mode at the start of every epoch.
    feature_extractor.train()
    classifier.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Clear stale gradients, then forward -> loss -> backward -> update.
        optimizer.zero_grad()
        features = feature_extractor(images)
        outputs = classifier(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        step = i + 1
        if step % LOG_EVERY == 0:
            # Logs the loss of this single batch, not an epoch average.
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, step, total_step, loss.item()))

# Only the feature extractor's weights are saved; the classifier is not persisted.
torch.save(feature_extractor.state_dict(), WEIGHTS_PATH)


Epoch [1/10], Step [100/938], Loss: 0.3368
Epoch [1/10], Step [200/938], Loss: 0.2967
Epoch [1/10], Step [300/938], Loss: 0.0489
Epoch [1/10], Step [400/938], Loss: 0.0949
Epoch [1/10], Step [500/938], Loss: 0.1043
Epoch [1/10], Step [600/938], Loss: 0.1573
Epoch [1/10], Step [700/938], Loss: 0.0907
Epoch [1/10], Step [800/938], Loss: 0.1074
Epoch [1/10], Step [900/938], Loss: 0.0277
Epoch [2/10], Step [100/938], Loss: 0.1325
Epoch [2/10], Step [200/938], Loss: 0.0464
Epoch [2/10], Step [300/938], Loss: 0.1763
Epoch [2/10], Step [400/938], Loss: 0.1241
Epoch [2/10], Step [500/938], Loss: 0.2299
Epoch [2/10], Step [600/938], Loss: 0.0841
Epoch [2/10], Step [700/938], Loss: 0.0507
Epoch [2/10], Step [800/938], Loss: 0.0178
Epoch [2/10], Step [900/938], Loss: 0.0708
Epoch [3/10], Step [100/938], Loss: 0.0634
Epoch [3/10], Step [200/938], Loss: 0.0985
Epoch [3/10], Step [300/938], Loss: 0.1077
Epoch [3/10], Step [400/938], Loss: 0.0387
Epoch [3/10], Step [500/938], Loss: 0.0916
Epoch [3/10