In [29]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Normalized image size
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = ImageFolder(root='/Users/coltenrodriguez/Desktop/dataset2/train', transform=transform)
val_dataset = ImageFolder(root='/Users/coltenrodriguez/Desktop/dataset2/val', transform=transform)

# data --> Pytorch Objects
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [30]:
import torch.nn as nn
import torchvision.models as models

# Initialize a resnet model
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # Assuming 2 classes: Constituent and Background

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [31]:
# Train the resnet model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 25

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f'Epoch {epoch}/{num_epochs} - Loss: {epoch_loss:.4f}')

Epoch 0/25 - Loss: 0.4361
Epoch 1/25 - Loss: 0.1334
Epoch 2/25 - Loss: 0.0446
Epoch 3/25 - Loss: 0.0462
Epoch 4/25 - Loss: 0.1779
Epoch 5/25 - Loss: 0.0530
Epoch 6/25 - Loss: 0.0731
Epoch 7/25 - Loss: 0.0407
Epoch 8/25 - Loss: 0.0352
Epoch 9/25 - Loss: 0.0132
Epoch 10/25 - Loss: 0.0235
Epoch 11/25 - Loss: 0.0234
Epoch 12/25 - Loss: 0.0098
Epoch 13/25 - Loss: 0.0024
Epoch 14/25 - Loss: 0.0103
Epoch 15/25 - Loss: 0.0076
Epoch 16/25 - Loss: 0.0021
Epoch 17/25 - Loss: 0.0007
Epoch 18/25 - Loss: 0.0002
Epoch 19/25 - Loss: 0.0001
Epoch 20/25 - Loss: 0.0014
Epoch 21/25 - Loss: 0.0004
Epoch 22/25 - Loss: 0.0002
Epoch 23/25 - Loss: 0.0005
Epoch 24/25 - Loss: 0.0002


In [41]:
torch.save(model.state_dict(), 'Preclassifier_model.pth')

In [43]:
from PIL import Image 

# Test the trained model. Can use files from /val for this
model.load_state_dict(torch.load('Preclassifier_model.pth'))
model.eval()

def classify_image(image_path):
    image = Image.open(image_path)
    image = transform(image).unsqueeze(0)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image)
        _, preds = torch.max(outputs, 1)
        class_names = train_dataset.classes
        print(class_names)
        return class_names[preds.item()]

print(classify_image('/Users/coltenrodriguez/Downloads/Img0026.tif'))

['Background', 'Constituent']
Constituent
