In [3]:
# basic imports
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torchvision import datasets, models, transforms
from torchvision.models import ResNet50_Weights


In [4]:
# Reproducibility: seed PyTorch's RNGs so the random initialization of the new
# classification head and the DataLoader shuffle order are identical on every
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-style evaluation preprocessing: resize the shorter side to 256 px,
# take the central 224x224 crop, convert to a float tensor in [0, 1], then
# normalize each RGB channel with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [None]:
# Dataset loading (presumably X-ray images -- TODO confirm once a dataset is chosen).
# Skip this cell for the time being until we have a dataset.
#
# ImageFolder expects one sub-directory per class inside each split directory
# and assigns integer labels in sorted order of those sub-directory names.
DATA_DIR = Path("/dataset")  # root containing `train/` and `val/`; adjust per machine
BATCH_SIZE = 64

train_dataset = datasets.ImageFolder(DATA_DIR / "train", transform=transform)
val_dataset = datasets.ImageFolder(DATA_DIR / "val", transform=transform)

# Labels are derived independently from each split's folder names, so a class
# folder missing from (or extra in) one split would silently shift label indices.
assert train_dataset.classes == val_dataset.classes, (
    f"train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
# ImageNet-pretrained ResNet-50 backbone (torchvision's IMAGENET1K_V2 weights).
# The weights enum is accessed through `models.` so this cell does not depend on
# `ResNet50_Weights` having been imported by some other cell (hidden state).
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

# Freeze every pretrained parameter so the backbone acts as a fixed feature extractor.
for param in model.parameters():
    param.requires_grad = False

# Replace the ImageNet classification head with one sized for our task. The new
# nn.Linear is created after the freeze loop, so its weight and bias stay trainable.
num_classes = 5  # TODO: set from the dataset, e.g. len(train_dataset.classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model = model.to(device)


In [6]:
# Cross-entropy computed on the raw logits emitted by the new head.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over the classification head only; the backbone is frozen,
# so its parameters are deliberately not handed to the optimizer.
optimizer = optim.SGD(params=model.fc.parameters(), momentum=0.9, lr=1e-3)

In [None]:
# Fine-tuning loop: train the classification head, then report the mean
# training loss and validation accuracy after every epoch.
# Get to this cell once we have decided on a dataset.

num_epochs = 10  # play around with this number

for epoch in range(num_epochs):
    # NOTE(review): train() also lets the backbone's BatchNorm layers update their
    # running statistics even though their weights are frozen; keep the backbone in
    # eval() if a strictly fixed feature extractor is wanted.
    model.train()
    running_loss = 0.0
    seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss returns the batch mean by default; weight it by the
        # batch size so the epoch average is per-sample.
        running_loss += loss.item() * labels.size(0)
        seen += labels.size(0)

    # Validation accuracy
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard empty loaders instead of raising ZeroDivisionError.
    train_loss = running_loss / seen if seen else float("nan")
    val_acc = 100 * correct / total if total else float("nan")
    print(
        f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, "
        f"Validation Accuracy: {val_acc:.2f}%"
    )


In [None]:
# Inference logic for our fine-tuned model.
# TO BE FIGURED OUT, THIS IS JUST TEMPORARY AND UNTESTED.


def preprocess_image(image_path, preprocess=None):
    """Load one image from disk and turn it into a model-ready batch of size 1.

    Args:
        image_path: Path to an image file readable by PIL.
        preprocess: Optional transform to apply. Defaults to the notebook's
            evaluation ``transform`` so validation and inference share exactly
            the same preprocessing.

    Returns:
        The transformed image tensor with a leading batch dimension of 1.
    """
    if preprocess is None:
        preprocess = transform
    # The context manager closes the file handle; convert() returns a new, fully
    # loaded RGB image, so it remains usable after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return preprocess(rgb).unsqueeze(0)


def predict(image_path, model, class_names):
    """Return the class name the model scores highest for a single image.

    Args:
        image_path: Path to the image to classify.
        model: Classifier returning one row of logits per input image.
        class_names: Sequence mapping each output index to a label; must have
            exactly one entry per model output.

    Raises:
        ValueError: if ``class_names`` does not match the number of model outputs.
    """
    model.eval()  # Ensure the model is in evaluation mode
    # Send the input to wherever the model's weights live instead of a global device.
    model_device = next(model.parameters()).device
    image = preprocess_image(image_path).to(model_device)

    with torch.no_grad():
        outputs = model(image)

    if outputs.shape[1] != len(class_names):
        raise ValueError(
            f"model has {outputs.shape[1]} outputs but {len(class_names)} class names were given"
        )
    predicted_idx = outputs.argmax(dim=1)[0].item()
    return class_names[predicted_idx]


# ImageFolder's `classes` list is ordered by label index, so it maps model
# outputs back to the folder names used during training.
class_names = train_dataset.classes

image_path = "path_to_your_image.jpg"  # replace with a real image path
predicted_class = predict(image_path, model, class_names)

print(f"The predicted class is: {predicted_class}")



In [2]:
# Inference on the stock ImageNet ResNet-50 (no fine-tuning) as a sanity check.
# The network is stored as `baseline_model` so this cell does not overwrite the
# fine-tuned `model` defined in the cells above.
from torchvision.io import ImageReadMode, read_image
from torchvision.models import resnet50, ResNet50_Weights

# Force 3-channel RGB decoding so grayscale or RGBA files still match the
# 3-channel normalization applied by the weights' preprocessing.
img = read_image("camera.jpeg", mode=ImageReadMode.RGB)

# initializing model with the default ImageNet weights
weights = ResNet50_Weights.DEFAULT
baseline_model = resnet50(weights=weights)
baseline_model.eval()

# preprocessing bundled with these weights
preprocess = weights.transforms()
batch = preprocess(img).unsqueeze(0)

# Top-1 prediction; no_grad avoids building an autograd graph for pure inference.
with torch.no_grad():
    prediction = baseline_model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

reflex camera: 22.1%
