# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets # For ImageNet (if used)
from PIL import Image # For handling image data

# 1. Model Definition (Base Architecture - Simplified OverFeat)
class OverFeatBase(nn.Module):
    """Simplified OverFeat ("fast" model) convolutional trunk plus a classification head.

    ``forward`` returns the pooled convolutional feature map of shape
    ``(batch, 1024, 6, 6)``; it does NOT apply the head. During pre-training,
    pass that map through ``self.classifier`` to obtain ``(batch, num_classes)``
    logits. ``features`` and ``avgpool`` can also be reused on their own as a
    feature extractor.

    Args:
        num_classes: Number of logits produced by ``classifier`` (1000 for ImageNet).
        hidden_dim: Width of the two hidden fully connected layers of ``classifier``.
    """

    def __init__(self, num_classes=1000, hidden_dim=4096):  # ImageNet classes by default
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(96, 256, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # padding=1 keeps the 3x3 layers size-preserving (as in the OverFeat
            # fast model). Without it a 224x224 input shrinks to a 2x2 map that
            # avgpool must stretch to 6x6; with it the trunk yields 5x5.
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Adaptive pooling gives a fixed 6x6 map regardless of input size, so the
        # head's input dimension (1024 * 6 * 6) does not depend on image size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Pre-training head. It starts with Flatten so it can be applied directly
        # to the 4-D output of forward().
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024 * 6 * 6, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return pooled trunk features of shape ``(batch, 1024, 6, 6)``."""
        x = self.features(x)
        x = self.avgpool(x)
        return x

# 2. Data Loading and Preprocessing

# For ImageNet (if pretraining on ImageNet)
# Training-time preprocessing: resize the smaller edge to 256, take a random
# 224x224 crop, randomly mirror horizontally, then convert to a tensor and
# normalize with the ImageNet channel mean/std.
transform_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ImageFolder assigns class indices from the sub-directory names under `root`.
train_dataset = datasets.ImageFolder(
    root="path/to/imagenet/train",
    transform=transform_train,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# 3. Pre-training (Simplified - assuming you train the entire base)
overfeat = OverFeatBase(num_classes=1000)  # for ImageNet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
overfeat = overfeat.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(overfeat.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)  # Step decay

num_epochs = 90  # Adjust based on needs

overfeat.train()  # make sure training-mode behavior (e.g. dropout) is active
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)

        # Forward pass: pooled conv features -> class logits via the base
        # model's classifier head.
        outputs = overfeat.classifier(overfeat(images))
        loss = criterion(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    scheduler.step()  # Update learning rate once per epoch

    if num_batches == 0:
        # Without this, the print below would hit an undefined/meaningless loss.
        raise RuntimeError("train_loader yielded no batches")
    # Report the mean loss over the epoch rather than only the last batch's.
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / num_batches:.4f}')

print("Finished Pre-training")

# 4. Freeze Layers for Feature Extraction

# Freeze the convolutional layers: Module.requires_grad_(False) clears
# requires_grad on every parameter of overfeat.features, so autograd stops
# tracking them and they receive no gradients.
overfeat.features.requires_grad_(False)

# 5. Define Feature Extraction Dataset and New Classifier

# Define a dataset for your bird species data (replace with your data loading)
class BirdDataset(Dataset):  # Example dataset
    """Map-style dataset of ``(image, label)`` pairs read from image files.

    Args:
        image_paths: Sequence of image file paths, one per sample.
        labels: Sequence of labels aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image before it is
            returned (e.g. a torchvision ``Compose``).

    Raises:
        ValueError: If ``image_paths`` and ``labels`` have different lengths.
    """

    def __init__(self, image_paths, labels, transform=None):
        # A length mismatch would otherwise surface only later, as an
        # IndexError (or silently unused paths) deep inside a DataLoader worker.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open the file in a context manager so its handle is released promptly.
        # convert() returns a loaded copy, which stays valid after the file closes.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Evaluation-style preprocessing for the bird images: the same Resize(256) and
# ImageNet Normalize as transform_train, but a deterministic 224x224 center crop
# (no random crop/flip).
transform_bird = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Placeholder data: replace with the real image paths and their integer labels.
image_paths = ["path/to/bird1.jpg", "path/to/bird2.jpg", ...]
labels = [0, 1, ...]
bird_dataset = BirdDataset(image_paths=image_paths, labels=labels, transform=transform_bird)

bird_loader = DataLoader(dataset=bird_dataset, batch_size=32, shuffle=True)

# 6. New Classifier (replacing fully connected layers from original OverFeat network)

# CrossEntropyLoss requires integer targets in [0, num_classes), so the output
# layer needs max(label) + 1 logits. Counting distinct labels (len(set(labels)))
# makes the head too small whenever label ids are not contiguous from 0.
num_bird_classes = max(labels) + 1  # Number of bird classes in dataset

# OverFeat with New Classifier
class OverFeatForBirds(nn.Module):
    """Bird-species classifier built on top of a pre-trained OverFeat trunk.

    ``features`` and ``avgpool`` are taken by reference from ``base_model``
    (not copied), so they share weights and ``requires_grad`` flags with it.
    A freshly initialised MLP head maps the flattened pooled features, which
    must have ``1024 * 6 * 6`` values per sample, to ``num_bird_classes`` logits.

    Args:
        base_model: Module exposing ``features`` and ``avgpool`` submodules.
        num_bird_classes: Number of output logits.
    """

    def __init__(self, base_model, num_bird_classes):
        super().__init__()
        # Submodules are registered in the order features, avgpool, classifier,
        # which fixes the order of parameters() and state_dict() keys.
        self.features = base_model.features
        self.avgpool = base_model.avgpool

        def hidden_stage(in_dim, out_dim):
            # One hidden stage of the head: Linear -> ReLU -> Dropout(0.5).
            return [nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True), nn.Dropout(0.5)]

        # Layout: 0 Linear, 1 ReLU, 2 Dropout, 3 Linear, 4 ReLU, 5 Dropout, 6 Linear.
        self.classifier = nn.Sequential(
            *hidden_stage(1024 * 6 * 6, 4096),
            *hidden_stage(4096, 4096),
            nn.Linear(4096, num_bird_classes),  # Bird class number
        )

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled.flatten(1))

# Reuse the pre-trained, frozen trunk and attach a freshly initialised bird head,
# sized from num_bird_classes. The model is moved to the selected device.
overfeat_birds = OverFeatForBirds(overfeat, num_bird_classes).to(device)

# 7. Training the New Classifier

criterion_birds = nn.CrossEntropyLoss()
# Only the new classifier's parameters are handed to the optimizer; the shared
# feature layers are not updated by it.
optimizer_birds = optim.SGD(overfeat_birds.classifier.parameters(), lr=0.001, momentum=0.9)

num_epochs_birds = 50  # Adjust

overfeat_birds.train()  # training-mode behavior, e.g. the classifier's Dropout
for epoch in range(num_epochs_birds):
    running_loss = 0.0
    num_batches = 0
    # The batch labels are called `targets`, not `labels`, so this loop does not
    # overwrite the module-level `labels` list that num_bird_classes is computed from.
    for images, targets in bird_loader:
        images = images.to(device)
        targets = targets.to(device)

        # Forward pass through the shared feature extractor plus the new classifier.
        outputs = overfeat_birds(images)
        loss = criterion_birds(outputs, targets)

        # Backward and optimize: only the classifier parameters are stepped.
        optimizer_birds.zero_grad()
        loss.backward()
        optimizer_birds.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Without this, the print below would hit an undefined/meaningless loss.
        raise RuntimeError("bird_loader yielded no batches")
    # Report the mean loss over the epoch rather than only the last batch's.
    print(f'Epoch [{epoch+1}/{num_epochs_birds}], Loss: {running_loss / num_batches:.4f}')

print("Finished Training New Classifier")