In [1]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [7]:
train_dir = '../Identify/Data/training'
validation_dir = '../Identify/Data/validation'

# Define parameters
img_height, img_width = 224, 224
batch_size = 32

# Per-channel normalization statistics shared by both pipelines (the standard
# ImageNet values used with torchvision's pretrained backbones).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Data augmentation and normalization for training
train_transforms = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation: same resize and normalization, but no random augmentation
validation_transforms = transforms.Compose([
    transforms.Resize((img_height, img_width)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Load datasets (one sub-directory per class under each split directory)
train_dataset = ImageFolder(train_dir, transform=train_transforms)
validation_dataset = ImageFolder(validation_dir, transform=validation_transforms)

# The model head is a single logit trained with BCE, so there must be exactly
# two classes, and both splits must map the same folder names to the same 0/1
# labels -- otherwise training and validation accuracy are silently wrong.
assert len(train_dataset.classes) == 2, (
    f"expected 2 classes for binary classification, found {train_dataset.classes}"
)
assert train_dataset.class_to_idx == validation_dataset.class_to_idx, (
    f"label mapping mismatch between splits: {train_dataset.class_to_idx} vs "
    f"{validation_dataset.class_to_idx}"
)

# Data loaders: shuffle only the training split
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

In [3]:
import torch.nn as nn
import torchvision.models as models

# Load ResNet-18 with ImageNet weights. IMAGENET1K_V1 is the checkpoint that the
# deprecated `pretrained=True` flag selected (multi-weight API, torchvision >= 0.13).
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze the base model: every existing parameter stops receiving gradients.
for param in model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a single-logit head for binary
# classification. The new nn.Linear is created after the freeze loop, so its
# parameters keep requires_grad=True and it is the only trainable layer.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)

# Move the model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)



In [4]:
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler

# Loss function and optimizer.
# BCEWithLogitsLoss applies the sigmoid itself, so raw logits are passed in.
# Only the new head (model.fc) is handed to the optimizer.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# Training the model
num_epochs = 15


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``; return the mean per-sample loss.

    Labels are cast to float and given a trailing dimension of size 1 to match
    the model's single-logit output.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device).float().unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by batch size so a short
        # final batch does not skew the epoch average.
        batch_n = labels.size(0)
        loss_sum += loss.item() * batch_n
        n_samples += batch_n
    if n_samples == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / n_samples


def evaluate(model, loader, criterion, device):
    """Return ``(mean per-sample loss, accuracy)`` of ``model`` on ``loader``.

    A sample is predicted positive when the sigmoid of its logit exceeds 0.5.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device).float().unsqueeze(1)
            outputs = model(inputs)

            batch_n = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_n

            predicted = (torch.sigmoid(outputs) > 0.5).int()
            correct += (predicted == labels.int()).sum().item()
            total += batch_n
    if total == 0:
        raise ValueError("validation loader yielded no samples")
    return loss_sum / total, correct / total


for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')

    # Evaluate the model on the validation set
    val_loss, val_accuracy = evaluate(model, validation_loader, criterion, device)
    print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}')


Epoch [1/15], Loss: 0.5337
Validation Loss: 0.4361, Accuracy: 0.8270
Epoch [2/15], Loss: 0.3890
Validation Loss: 0.3753, Accuracy: 0.8580
Epoch [3/15], Loss: 0.3392
Validation Loss: 0.3402, Accuracy: 0.8610
Epoch [4/15], Loss: 0.3023
Validation Loss: 0.3286, Accuracy: 0.8730
Epoch [5/15], Loss: 0.2729
Validation Loss: 0.3177, Accuracy: 0.8770
Epoch [6/15], Loss: 0.2628
Validation Loss: 0.2959, Accuracy: 0.8840
Epoch [7/15], Loss: 0.2373
Validation Loss: 0.3088, Accuracy: 0.8810
Epoch [8/15], Loss: 0.2388
Validation Loss: 0.2814, Accuracy: 0.8890
Epoch [9/15], Loss: 0.2467
Validation Loss: 0.2932, Accuracy: 0.8830
Epoch [10/15], Loss: 0.2307
Validation Loss: 0.2836, Accuracy: 0.8870
Epoch [11/15], Loss: 0.2323
Validation Loss: 0.3162, Accuracy: 0.8740
Epoch [12/15], Loss: 0.2067
Validation Loss: 0.2825, Accuracy: 0.8920
Epoch [13/15], Loss: 0.2072
Validation Loss: 0.2884, Accuracy: 0.8820
Epoch [14/15], Loss: 0.1990
Validation Loss: 0.2687, Accuracy: 0.8950
Epoch [15/15], Loss: 0.1942
V

In [5]:
# Save the model: only the state_dict (parameter and buffer tensors) is written,
# so reloading requires an identically constructed model object, as below.
model_path = 'binary_classification_model.pth'
torch.save(model.state_dict(), model_path)

# Load the model. map_location remaps tensors that were saved from a CUDA
# device onto the current `device`, so the checkpoint also loads in a
# CPU-only kernel.
# NOTE(review): on torch >= 1.13, consider weights_only=True if the file may
# come from an untrusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
