In [1]:
import torch
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import torch.optim as optim

In [2]:
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

In [3]:
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ]
)
val_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ]
)

In [4]:
import os
from torch.utils.data import Dataset


class WheatDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.classes = [x.split("_")[1] for x in sorted(os.listdir(data_dir))]
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        for label_dir in sorted(os.listdir(data_dir)):
            label_name = label_dir.split("_")[1]
            pth = os.path.join(data_dir, label_dir)
            for img_name in os.listdir(pth):
                self.image_paths.append(os.path.join(pth, img_name))
                self.labels.append(self.class_to_idx[label_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [5]:
from torch.utils.data import DataLoader

train_dataset = WheatDataset(
    data_dir="/home/sergei/Downloads/GrainSetData/wheat/train",
    transform=train_transforms,
)
val_dataset = WheatDataset(
    data_dir="/home/sergei/Downloads/GrainSetData/wheat/test", transform=val_transforms
)


In [41]:
len(train_dataset)/64

2812.5

In [6]:
import numpy as np


ar = np.array(train_dataset.labels)

In [7]:
WORKERS = 16

In [42]:
train_loader = DataLoader(
    train_dataset,
    batch_size=180,
    shuffle=True,
    num_workers=WORKERS,
    pin_memory=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=180,
    shuffle=False,
    num_workers=WORKERS,
    pin_memory=True,
    drop_last=True,
)


In [43]:
import torch.nn as nn


class WheatConditionClassifier(nn.Module):
    def __init__(self, num_classes=8):
        super(WheatConditionClassifier, self).__init__()

        # Load pre-trained ResNet50
        self.base_model = models.resnet50(pretrained=True)

        # Freeze base model parameters
        for param in self.base_model.parameters():
            param.requires_grad = False

        # Remove original classifier

        # Add custom layers
        self.base_model.fc = nn.Linear(
            in_features=2048, out_features=num_classes, bias=True
        )

    def forward(self, x):
        return self.base_model(x)

In [44]:
model = WheatConditionClassifier()



In [46]:
learn_rate = 1.2e-3
wd = 1e-4

In [47]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.base_model.fc.parameters(),lr=learn_rate,weight_decay=wd)
model = nn.DataParallel(model)

In [48]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [49]:
model.to(device)

DataParallel(
  (module): WheatConditionClassifier(
    (base_model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       

In [50]:
len(train_loader)

1000

In [51]:
# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch={epoch}")
    # Training phase
    model.train()  # Set the model to training mode
    running_loss = 0.0
    correct = 0
    total = 0

    # Iterate over the training data
    for inputs, labels in train_loader:
        # Move data to the specified device (GPU or CPU)
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients to prevent accumulation
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        # Calculate the loss
        loss = criterion(outputs, labels)
        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate training loss
        running_loss += loss.item()
        # Get predictions
        _, predicted = torch.max(outputs.data, 1)
        # Update total and correct predictions
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print("One iteration")

    # Calculate average training loss and accuracy
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total

    # Validation phase
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    # Disable gradient calculation for validation
    with torch.no_grad():
        # Iterate over the validation data
        for inputs, labels in val_loader:
            # Move data to the specified device (GPU or CPU)
            inputs, labels = inputs.to(device), labels.to(device)
            # Forward pass
            outputs = model(inputs)
            # Calculate the loss
            loss = criterion(outputs, labels)

            # Accumulate validation loss
            val_loss += loss.item()
            # Get predictions
            _, predicted = torch.max(outputs.data, 1)
            # Update total and correct predictions
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    # Calculate average validation loss and accuracy
    val_loss /= len(val_loader)
    val_acc = val_correct / val_total

    # Print epoch results
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch=0
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration
One iteration


KeyboardInterrupt: 