In [None]:
from typing import List
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from tqdm import tqdm

In [None]:
class StoneClassifier(nn.Module):
    """Small CNN for two-way image classification.

    The training code in this notebook labels the two outputs "nostone" and
    "stone".

    Architecture:
        - Three 3x3 convolutions with stride 2 and padding 1, going
          3 -> 8 -> 16 -> 32 channels, each followed by ReLU.
        - A 128-unit fully connected hidden layer with ReLU.
        - A 2-way output layer.

    ``fc1`` is sized for a 32x3x3 feature map, which is what a 24x24 input
    produces (24 -> 12 -> 6 -> 3). Inputs whose spatial size does not reduce
    to 3x3 will fail at ``fc1``.

    ``forward`` returns softmax probabilities (each row sums to 1), not logits.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become the state_dict keys of saved
        # checkpoints (note there is no ``fc2``), so they must not be renamed.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(in_features=32 * 3 * 3, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=2)

    def forward(self, x):
        """Map a batch of images ``(N, 3, H, W)`` to ``(N, 2)`` class probabilities."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
        return F.softmax(self.fc3(hidden), dim=1)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
def _count_per_class_hits(model, loader, classes, device):
    """Return ``(correct, total)`` dicts, keyed by class name, for one pass over ``loader``.

    A sample counts as correct when the argmax of the model output equals its
    label. Puts the model in eval mode and disables autograd for the pass.
    """
    correct = {name: 0 for name in classes}
    total = {name: 0 for name in classes}
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs.to(device)).argmax(dim=1).cpu()
            hits = (predictions == labels).tolist()
            # Convert labels/hits to Python once per batch instead of comparing
            # and indexing device tensors element by element (one sync each).
            for label, hit in zip(labels.tolist(), hits):
                total[classes[label]] += 1
                correct[classes[label]] += int(hit)
    return correct, total


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def train(dataset_path: str, output_path: str, device: torch.device, batch_size: int = 64,
          class_weights: List[int] = [1, 3], epochs: int = 100, lr: float = 0.001,
          num_workers: int = 4):
    """Train a StoneClassifier on an ImageFolder dataset and checkpoint it every epoch.

    Expects ``<dataset_path>/train`` and ``<dataset_path>/valid`` directories in
    torchvision ``ImageFolder`` layout (one sub-directory per class). Each image
    is center-cropped to 24x24, converted to a tensor and normalized with
    mean 0.5 / std 0.5 per channel.

    After each epoch, the function prints per-class validation accuracy and
    saves the model's state_dict to
    ``<output_path>/<epoch>_<nostone_acc>_<stone_acc>.pth``. The accuracies in
    the filename are fractions rounded to 2 decimals, or ``nan`` if a class has
    no validation samples.

    Args:
        dataset_path: Root directory containing ``train`` and ``valid``.
        output_path: Directory for checkpoints; created if it does not exist.
        device: Device to train and evaluate on.
        batch_size: Mini-batch size for both data loaders.
        class_weights: Per-class loss weights, indexed by class label. The
            default weights the second class 3x. The list is never mutated.
        epochs: Number of passes over the training set.
        lr: Adam learning rate.
        num_workers: Number of DataLoader worker processes per loader.

    Raises:
        ValueError: If ``class_weights`` does not have exactly one entry per class.
    """
    import os

    # NOTE(review): ImageFolder assigns label indices from the sorted class
    # sub-directory names; this list assumes they sort as nostone, stone.
    # Confirm against data_train.classes.
    classes = ["nostone", "stone"]
    if len(class_weights) != len(classes):
        raise ValueError(
            f"class_weights must have {len(classes)} entries, got {len(class_weights)}"
        )

    transform = transforms.Compose(
        [
            transforms.CenterCrop(24),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    data_train = ImageFolder(os.path.join(dataset_path, "train"), transform)
    data_val = ImageFolder(os.path.join(dataset_path, "valid"), transform)
    train_loader = torch.utils.data.DataLoader(data_train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(data_val,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             drop_last=False,
                                             num_workers=num_workers)

    model = StoneClassifier().to(device)
    # Previously the loss hard-coded [1, 3] and ignored ``class_weights``.
    # NLLLoss with ``weight`` matches CrossEntropyLoss's weighting semantics.
    weights = torch.tensor(class_weights, dtype=torch.float32, device=device)
    criterion = nn.NLLLoss(weight=weights)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    os.makedirs(output_path, exist_ok=True)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad()

            probabilities = model(inputs.to(device))
            # StoneClassifier.forward already applies softmax, so passing its
            # output to CrossEntropyLoss would apply (log_)softmax a second
            # time and flatten the gradients. Take the log of the
            # probabilities and feed NLLLoss instead. The clamp keeps log()
            # finite if a probability underflows to 0; such samples then
            # contribute no gradient.
            log_probs = torch.log(probabilities.clamp_min(1e-12))
            loss = criterion(log_probs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0

        correct_pred, total_pred = _count_per_class_hits(model, val_loader, classes, device)
        for class_name in classes:
            # NaN (not ZeroDivisionError) when a class is absent from validation.
            accuracy = 100 * _safe_ratio(correct_pred[class_name], total_pred[class_name])
            print(f'Accuracy for class: {class_name:5s} is {accuracy:.1f} %')
        nostone_accuracy = _safe_ratio(correct_pred["nostone"], total_pred["nostone"])
        stone_accuracy = _safe_ratio(correct_pred["stone"], total_pred["stone"])
        checkpoint_name = f"{epoch}_{round(nostone_accuracy, 2)}_{round(stone_accuracy, 2)}.pth"
        torch.save(model.state_dict(), os.path.join(output_path, checkpoint_name))

    print('Finished Training')

In [None]:
# Train on the local dataset and write one checkpoint per epoch into ../weights/.
train(dataset_path="../surviv-rl/dataset/", output_path="../weights/", device=device)