In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader


In [2]:
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

for param in model.parameters():
    param.requires_grad = False

num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 2) 
)


In [3]:
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


In [4]:
train_dataset = datasets.ImageFolder(root='dataset/train', transform=train_transform)
val_dataset = datasets.ImageFolder(root='dataset/val', transform=test_val_transform)
test_dataset = datasets.ImageFolder(root='dataset/test', transform=test_val_transform)

trainloader = DataLoader(train_dataset, batch_size=100, shuffle=True)
valloader = DataLoader(val_dataset, batch_size=100, shuffle=False)
testloader = DataLoader(test_dataset, batch_size=100, shuffle=False)


In [5]:
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [8]:
def train(model, trainloader, optimizer, criterion, device):
    model.train()
    for data in trainloader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    return loss.item()

def validate(model, validationloader, device):
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for data in validationloader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

for epoch in range(20): 
    train_loss = train(model, trainloader, optimizer, criterion, device)
    validation_accuracy = validate(model, valloader, device)
    print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}, Validation Accuracy: {validation_accuracy:.4f}')


Epoch 1, Loss: 0.6197, Validation Accuracy: 0.6837
Epoch 2, Loss: 0.5886, Validation Accuracy: 0.7653
Epoch 3, Loss: 0.5925, Validation Accuracy: 0.7143
Epoch 4, Loss: 0.5135, Validation Accuracy: 0.7551
Epoch 5, Loss: 0.5600, Validation Accuracy: 0.7347
Epoch 6, Loss: 0.5233, Validation Accuracy: 0.8163
Epoch 7, Loss: 0.4554, Validation Accuracy: 0.7857
Epoch 8, Loss: 0.6124, Validation Accuracy: 0.7551
Epoch 9, Loss: 0.4440, Validation Accuracy: 0.8265
Epoch 10, Loss: 0.4253, Validation Accuracy: 0.8776
Epoch 11, Loss: 0.4168, Validation Accuracy: 0.8878
Epoch 12, Loss: 0.4373, Validation Accuracy: 0.8673
Epoch 13, Loss: 0.3519, Validation Accuracy: 0.8980
Epoch 14, Loss: 0.3173, Validation Accuracy: 0.8061
Epoch 15, Loss: 0.3295, Validation Accuracy: 0.8163
Epoch 16, Loss: 0.2564, Validation Accuracy: 0.8878
Epoch 17, Loss: 0.3247, Validation Accuracy: 0.8776
Epoch 18, Loss: 0.2552, Validation Accuracy: 0.8878
Epoch 19, Loss: 0.2327, Validation Accuracy: 0.9184
Epoch 20, Loss: 0.236

In [9]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the test images: {100 * correct // total}%')


Accuracy of the network on the test images: 89%


In [10]:
import gc

model.cpu()
del model
gc.collect()
torch.cuda.empty_cache()

: 