# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import pickle

# Load the pickled dataset and dataloader produced by an earlier step.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load .pkl files that come from a trusted source.
def _unpickle(path):
    """Return the object stored in the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


dataset = _unpickle("dataset.pkl")
dataloader = _unpickle("dataloader.pkl")

# Define CNN model (pre-trained ResNet-18 with a new classification head)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum API; ResNet18_Weights.DEFAULT is the same ImageNet checkpoint the old
# flag loaded. Fall back to the legacy flag on older torchvision releases.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
else:
    model = models.resnet18(pretrained=True)

# Replace the ImageNet head with one sized to the number of distinct labels.
# NOTE(review): CrossEntropyLoss expects integer targets in 0..num_classes-1;
# this assumes dataset.labels are already encoded that way — confirm.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(set(dataset.labels)))
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train_model(model, dataloader, criterion, optimizer, epochs=10, *,
                device=None, save_path="histology_model.pth"):
    """Train ``model`` for ``epochs`` passes over ``dataloader``, then save it.

    Args:
        model: The ``nn.Module`` to train; it is switched to training mode
            at the start of every epoch.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of full passes over ``dataloader``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has none), so the
            function no longer relies on a module-level ``device`` global.
        save_path: File that ``model.state_dict()`` is written to once
            training finishes.

    Returns:
        The trained ``model`` (the same object, updated in place).
    """
    if device is None:
        # Batches must live on the same device as the weights.
        first_param = next(model.parameters(), None)
        device = (first_param.device if first_param is not None
                  else torch.device("cpu"))

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # .item() yields a plain Python float, so the accumulator does not
            # keep a reference to this batch's autograd graph.
            total_loss += loss.item()

        # Sum of the per-batch loss values over the epoch (not an average);
        # the format is kept unchanged for anyone parsing these logs.
        print(f"Epoch {epoch+1}: Loss={total_loss:.4f}")

    # Save the trained model
    torch.save(model.state_dict(), save_path)
    print("Model saved!")
    return model

# Run training for 10 epochs; train_model writes the weights to disk when done.
train_model(
    model=model,
    dataloader=dataloader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=10,
)
