In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import random
from tqdm import tqdm

In [2]:
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        
        # Encoder
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),  # 112x112
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 56x56
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 28x28
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # 14x14
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # 7x7
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, 7, stride=1, padding=0),  # 1x1
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        
        self.encoder_linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 512),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder_linear = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Unflatten(1, (512, 1, 1))
        )
        
        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 7, stride=1, padding=0),  # 7x7
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 14x14
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 28x28
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 56x56
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 112x112
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),  # 224x224
            nn.Tanh()
        )

    def encode(self, x):
        x = self.encoder_conv(x)
        return self.encoder_linear(x)

    def decode(self, x):
        x = self.decoder_linear(x)
        return self.decoder_conv(x)

    def forward(self, x):
        encoded = self.encode(x)
        decoded = self.decode(encoded)
        return decoded


In [3]:
# Data loading and preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

train_dataset = datasets.ImageFolder(root='/home/dl_class/data/NEA/NEUdata_split/Train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.ImageFolder(root='/home/dl_class/data/NEA/NEUdata_split/Test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [4]:
# Model, loss, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

In [12]:
# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}', leave=False):
        img, _ = batch
        img = img.to(device)

        optimizer.zero_grad()
        output = model(img)
        loss = criterion(output, img)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

                                                                                                                                                                                                                   

In [13]:
# Generate reconstructions
model.eval()
test_images = []
reconstructed_images = []

with torch.no_grad():
    for i, (img, _) in enumerate(test_loader):
        if i >= 50:  # We only need 50 images
            break
        img = img.to(device)
        output = model(img)
        test_images.extend(img.cpu())
        reconstructed_images.extend(output.cpu())

test_images = torch.stack(test_images[:50])
reconstructed_images = torch.stack(reconstructed_images[:50])

In [7]:
# Load and modify the pretrained classifier
import torchvision.models as models

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

model_ft = models.resnet18(pretrained=True)
set_parameter_requires_grad(model_ft,True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(train_dataset.classes))
model_ft = model_ft.to(device)

# Train the classifier
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Train the classifier 
num_epochs = 10
for epoch in range(num_epochs):
    model_ft.train()
    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}', leave=False):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer_ft.zero_grad()
        outputs = model_ft(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_ft.step()


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/durga/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44.7M/44.7M [00:04<00:00, 11.7MB/s]
                                                                                                                                                                                                                   

In [14]:
# Evaluate on original and reconstructed images
model_ft.eval()
correct_original = 0
correct_reconstructed = 0

with torch.no_grad():
    for i in range(50):
        # Original images
        outputs = model_ft(test_images[i].unsqueeze(0).to(device))
        _, predicted = torch.max(outputs, 1)
        correct_original += (predicted == test_dataset.targets[i]).sum().item()

        # Reconstructed images
        outputs = model_ft(reconstructed_images[i].unsqueeze(0).to(device))
        _, predicted = torch.max(outputs, 1)
        correct_reconstructed += (predicted == test_dataset.targets[i]).sum().item()

accuracy_original = correct_original / 50
accuracy_reconstructed = correct_reconstructed / 50

print(f"Accuracy on original images: {accuracy_original:.4f}")
print(f"Accuracy on reconstructed images: {accuracy_reconstructed:.4f}")

Accuracy on original images: 1.0000
Accuracy on reconstructed images: 0.0000
