In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision. transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from efficientnet_pytorch import EfficientNet
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

#  Device Configuration
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  Data Preprocessing & Augmentation (training split only)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),  # Ensure 3 channels for EfficientNet
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    # One mean/std value per channel; maps ToTensor's [0, 1] range to [-1, 1].
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

#  Deterministic preprocessing for the validation split.
# Same resize / channel / normalization steps as above but WITHOUT the random
# flip and rotation, so validation metrics are reproducible and are measured
# on unaugmented images.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

#  Load MRI Dataset (Replace with actual dataset path)
# NOTE: train_dir and test_dir are not defined in this section; they are
# expected to be set before this point.
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
val_dataset = datasets.ImageFolder(root=test_dir, transform=val_transform)

# Shuffle only the training data; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

#  Load Pretrained EfficientNet-B0 & Modify for 4 Classes
class BrainTumorModel(nn.Module):
    """Pretrained EfficientNet-B0 with a small trainable classification head.

    All pretrained parameters are frozen; only the replacement ``_fc`` head
    (created after the freeze) is trainable.

    The forward pass returns raw, unnormalized class scores (logits).
    ``nn.CrossEntropyLoss`` applies log-softmax internally, so the head must
    not end with its own Softmax layer. If class probabilities are needed at
    inference time, apply ``torch.softmax(logits, dim=1)`` to the output.

    Args:
        num_classes: Number of output classes (default 4).
    """

    def __init__(self, num_classes=4):
        super().__init__()
        self.model = EfficientNet.from_pretrained("efficientnet-b0")

        # Freeze base model layers
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the final layer. Modules created here are new, so their
        # parameters have requires_grad=True and are the only ones trained.
        # No trailing Softmax: it has no parameters, so dropping it leaves the
        # state_dict keys unchanged while letting the loss see real logits.
        self.model._fc = nn.Sequential(
            nn.Linear(1280, 128),  # Reduce feature size
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),  # One logit per class
        )

    def forward(self, x):
        """Return the logits produced by the wrapped network for batch ``x``."""
        return self.model(x)

# Build the 4-class network and move it to the selected device.
model = BrainTumorModel(num_classes=4)
model = model.to(device)

#  Loss Function & Optimizer
criterion = nn.CrossEntropyLoss()  # Multi-class loss function
# Adam is given only the parameters of the replacement `_fc` head.
head_params = model.model._fc.parameters()
optimizer = optim.Adam(head_params, lr=1e-3)

#  Training Loop
num_epochs = 1
train_losses, val_losses = [], []  # one entry per epoch, consumed by the plot below

for epoch in tqdm(range(num_epochs)):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Epoch loss is the mean of the per-batch losses.
    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # ---- Validation phase (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Predicted class is the index of the highest score in each row.
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_loss = val_loss / len(val_loader)
    val_losses.append(val_loss)

    # Accuracy as a percentage of correctly classified validation samples.
    val_acc = 100 * (correct / total)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

#  Plot Loss Curves
# Draw the per-epoch training and validation losses on the same axes.
for curve, curve_label in ((train_losses, "Train Loss"), (val_losses, "Val Loss")):
    plt.plot(curve, label=curve_label)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

#  Save the Model
# Only the weights (state_dict) are written, not the module object itself.
checkpoint_path = "brain_tumor_model_4_classes.pth"
torch.save(model.state_dict(), checkpoint_path)