Import dataset for classification, Load the saved model and train

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
import timm
import numpy as np

In [15]:

# %% Paths
train_data_path = r'C:\Users\mithu\Desktop\VIT Projects\III year\AI Project\Chest\train'
test_data_path = r'C:\Users\mithu\Desktop\VIT Projects\III year\AI Project\Chest\valid'
resnet_weights_path = r'C:\Users\mithu\Desktop\VIT Projects\III year\AI Project\fine_tuned_model_weights.pth'

train_transforms = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transforms = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


In [16]:

train_dataset = datasets.ImageFolder(root=train_data_path, transform=train_transforms)
test_dataset = datasets.ImageFolder(root=test_data_path, transform=test_transforms)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [17]:
# %% Hybrid Model
class HybridResNetViTTiny(nn.Module):
    def __init__(self, num_classes=4, patch_size=2, resnet_weights=None):
        super(HybridResNetViTTiny, self).__init__()

        # Load custom fine-tuned ResNet-18
        resnet = models.resnet18(pretrained=False)
        if resnet_weights:
            state_dict = torch.load(resnet_weights)
            state_dict = {k: v for k, v in state_dict.items() if 'fc' not in k}
            resnet.load_state_dict(state_dict, strict=False)

        self.backbone = nn.Sequential(*list(resnet.children())[:-2])  # [B, 512, H, W]

        # Patchify + Project to ViT input
        self.patch_size = patch_size
        self.vit_dim = 192  # vit_tiny dim
        self.projector = nn.Linear(512 * patch_size * patch_size, self.vit_dim)

        # Use ViT-tiny encoder blocks
        vit_tiny = timm.create_model('vit_tiny_patch16_224', pretrained=True)
        self.vit_blocks = vit_tiny.blocks
        self.norm = vit_tiny.norm

        # Classifier
        self.classifier = nn.Linear(self.vit_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        features = self.backbone(x)  # [B, 512, H, W]
        _, C, H, W = features.shape

        # Patchify
        patches = features.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(B, C, -1, self.patch_size * self.patch_size)
        patches = patches.permute(0, 2, 1, 3).reshape(B, -1, C * self.patch_size * self.patch_size)

        # Linear projection
        x_proj = self.projector(patches)  # [B, num_patches, 192]

        # ViT encoding
        for blk in self.vit_blocks:
            x_proj = blk(x_proj)

        x_proj = self.norm(x_proj)
        x_cls = x_proj.mean(dim=1)  # mean pooling
        out = self.classifier(x_cls)
        return out


In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HybridResNetViTTiny(
    num_classes=len(train_dataset.classes),
    patch_size=8,
    resnet_weights=resnet_weights_path
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

  state_dict = torch.load(resnet_weights)


In [20]:
# %% Training Loop
num_epochs = 10
best_f1 = 0.0

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Evaluation
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    f1 = f1_score(all_labels, all_preds, average='weighted')
    print(f'F1 Score after Epoch {epoch+1}: {f1:.4f}')

    # Save model with best F1
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), 'best_hybrid_vit_model.pth')
        print("✅ Saved model with best F1!")

# %% Final Accuracy
accuracy = sum(np.array(all_preds) == np.array(all_labels)) / len(all_labels)
print(f'Final Accuracy: {accuracy:.4f}')

Epoch [1/10], Loss: 0.6871
F1 Score after Epoch 1: 0.5634
✅ Saved model with best F1!
Epoch [2/10], Loss: 0.4101
F1 Score after Epoch 2: 0.6776
✅ Saved model with best F1!
Epoch [3/10], Loss: 0.0444
F1 Score after Epoch 3: 0.7864
✅ Saved model with best F1!
Epoch [4/10], Loss: 0.0232
F1 Score after Epoch 4: 0.7864
Epoch [5/10], Loss: 0.8750
F1 Score after Epoch 5: 0.8374
✅ Saved model with best F1!
Epoch [6/10], Loss: 0.4055
F1 Score after Epoch 6: 0.7696
Epoch [7/10], Loss: 0.0208
F1 Score after Epoch 7: 0.8356
Epoch [8/10], Loss: 0.0015
F1 Score after Epoch 8: 0.9181
✅ Saved model with best F1!
Epoch [9/10], Loss: 0.0016
F1 Score after Epoch 9: 0.8360
Epoch [10/10], Loss: 0.3181
F1 Score after Epoch 10: 0.8776
Final Accuracy: 0.8750
