In [1]:
from torchvision import transforms
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder

In [8]:
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()
])

data = ImageFolder(root="/home/rynutty/Documents/ProgrammingProjects/CustomModels/datasets/CropDisease", transform=transform)
perm = torch.randperm(len(data))

cutoff_idx = int(len(perm) * 0.8)

train_indices = perm[:cutoff_idx]
eval_indices = perm[cutoff_idx:]

train_subset = Subset(data, train_indices)
eval_subset = Subset(data, eval_indices)

train_dataloader = DataLoader(train_subset, batch_size=48, shuffle=True)
eval_dataloader = DataLoader(eval_subset, batch_size=64, shuffle=True)

In [133]:
class CropDiseaseClassifier(nn.Module):
    def __init__(self):
        super().__init__()

        self.res_block1 = ResidualBlock(in_channels=3, out_channels=18)
        self.res_block2 = ResidualBlock(in_channels=18, out_channels=36)
        self.res_block3 = ResidualBlock(in_channels=36, out_channels=72)
        self.ds_block1 = ResidualBlock(in_channels=72, out_channels=142, stride=2)
        self.ds_block2 = ResidualBlock(in_channels=142, out_channels=284, stride=2)
        self.ds_block3 = ResidualBlock(in_channels=284, out_channels=568, stride=2)

        dummy_ex = torch.randn((1, 3, 224, 224))
        dummy_ex = self.res_block3(self.res_block2(self.res_block1(dummy_ex)))
        dummy_ex = self.ds_block3(self.ds_block2(self.ds_block1(dummy_ex)))
        flattened = torch.flatten(dummy_ex, start_dim=1)

        self.fc = nn.Linear(in_features=flattened.size(1), out_features=23)

    def forward(self, x):
        x = self.res_block3(self.res_block2(self.res_block1(x)))
        x = self.ds_block3(self.ds_block2(self.ds_block1(x)))
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)
    


class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1)
        self.a = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, stride=1, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)


        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(num_features=out_channels)
            )
        else:
            self.shortcut = nn.Identity()


    def forward(self, x):
        identity = self.shortcut(x)
        x = self.a(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += identity
        return self.a(x)


In [134]:
model = CropDiseaseClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
loss_fn = nn.CrossEntropyLoss()

In [135]:
epochs = 5

for epoch in range(1, epochs+1):

    print("Epoch", epoch, "in progress...")

    train_working_loss = 0
    train_total = 0
    train_correct= 0

    print("\nTraining...")

    model.train()
    for image_batch, label_batch in train_dataloader:
        optimizer.zero_grad()
        logits = model(image_batch)
        loss = loss_fn(logits, label_batch)
        loss.backward()
        optimizer.step()

        train_working_loss += loss.item()
        train_total += len(image_batch)
        predicted = logits.max(1)
        train_correct += (predicted.indices == label_batch).sum().item()

        print(f"Avg training loss: {train_working_loss / train_total}, Acc: {train_correct / train_total}")

    eval_working_loss = 0
    eval_total = 0
    eval_correct = 0

    print("\nEvalutaing...")

    model.eval()
    with torch.no_grad():
        for image_batch, label_batch in eval_dataloader:
            logits = model(image_batch)
            loss = loss_fn(logits, label_batch)

            eval_working_loss += loss.item()
            eval_total += len(image_batch)
            predicted = logits.max(1)
            eval_correct += (predicted.indices == label_batch).sum().item()

    print(f"Avg eval loss: {eval_working_loss / eval_total}, Acc: {eval_correct / eval_total}")


Epoch 1 in progress...

Training...
Avg training loss: 0.06500361363093059, Acc: 0.08333333333333333
Avg training loss: 1.9073848227659862, Acc: 0.07291666666666667


KeyboardInterrupt: 