In [94]:
import torch
import torchvision
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import numpy as np

In [95]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [96]:
# ToTensor scales pixel values into [0, 1]; Normalize with mean = std = 0.5 on each
# of the three channels then maps every channel onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# One sub-directory per class under train/ (ImageFolder layout).
dataset = ImageFolder(root="train/", transform=transform)

In [97]:
seed = 8912
# Seed PyTorch's global RNG. The trailing ';' suppresses the returned
# torch.Generator repr so the cell does not print noise output.
torch.manual_seed(seed);

<torch._C.Generator at 0x1f7e131dc70>

In [98]:
# Hold out a fixed number of images for validation; the rest is training data.
val_size = 13500
train_size = len(dataset) - val_size
assert train_size > 0, f"val_size={val_size} leaves no training data (dataset has {len(dataset)} images)"

# A dedicated, seeded generator makes the split independent of how much of the global
# RNG other cells have consumed, so re-running this cell cannot reshuffle images
# between the validation and training sets.
split_generator = torch.Generator().manual_seed(seed)
val_dataset, train_dataset = data.random_split(
    dataset, lengths=[val_size, train_size], generator=split_generator
)

# Pinned host memory only speeds up host->GPU copies; without a GPU it just triggers a warning.
use_pinned_memory = device == "cuda"
train_loader = data.DataLoader(train_dataset, batch_size=2000, shuffle=True, pin_memory=use_pinned_memory, num_workers=2)
val_loader = data.DataLoader(val_dataset, batch_size=2000, pin_memory=use_pinned_memory, num_workers=2)
print(len(val_dataset), len(train_dataset))

13500 74511


In [99]:
class ImageClassifierNetwork(nn.Module):
    """Base class adding cross-entropy train/validation helpers to a classifier.

    Subclasses implement ``forward`` returning per-class scores, indexed as
    ``(batch, num_classes)`` (``_accuracy`` takes the argmax over dim 1).
    Batches are moved to the notebook-level ``device`` before the forward pass.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def _move_batch(self, input, labels):
        """Move inputs and labels to ``device``; non_blocking pairs with pinned host memory."""
        return (
            input.to(device=device, non_blocking=True),
            labels.to(device=device, non_blocking=True),
        )

    def train_step(self, input, labels):
        """Return the cross-entropy loss of one training batch (autograd graph kept for backward)."""
        input, labels = self._move_batch(input, labels)
        preds = self(input)
        return self.loss(preds, labels)

    def val_step(self, input, labels):
        """Return ``{"loss": tensor, "accuracy": tensor}`` for one validation batch."""
        input, labels = self._move_batch(input, labels)
        # Validation never back-propagates, so skip building the autograd graph.
        with torch.no_grad():
            preds = self(input)
            # The loss must compare predictions (not the raw input images) with the labels.
            loss = self.loss(preds, labels)
            accuracy = self._accuracy(preds, labels)

        return {"loss": loss, "accuracy": accuracy}

    def val_epoch_end(self, preformance_measurement_data):
        """Average per-batch ``val_step`` results into epoch-level scalars.

        Args:
            preformance_measurement_data: iterable of ``val_step`` result dicts.

        Returns:
            ``{"loss": float, "accuracy": float}`` — unweighted means over batches,
            so a smaller final batch counts as much as a full one.
        """
        # .item() works for CUDA tensors and tensors that require grad; .numpy() raises on both.
        avg_accuracy = np.mean([x["accuracy"].item() for x in preformance_measurement_data])
        avg_loss = np.mean([x["loss"].item() for x in preformance_measurement_data])

        # Return the scalar itself (previously it was wrapped in a set literal).
        return {"loss": avg_loss, "accuracy": avg_accuracy}

    def _accuracy(self, preds, labels):
        """Fraction of rows whose argmax equals the label, as a 0-d CPU float tensor."""
        batch_size = len(preds)

        pred_indices = torch.argmax(preds, dim=1)
        return torch.tensor(torch.sum(pred_indices == labels).item() / batch_size)





In [100]:
class ResNetBlock(nn.Module):
    """ResNet block with BN, full pre-activation and an identity shortcut.

    Main path: BN -> ReLU -> Conv -> BN -> ReLU -> Conv. The block input is added
    unchanged to the main-path output, so the main path must preserve the input
    shape (out_channels_conv2 == in_channels, and stride/padding that keep H and W).
    """
    def __init__(self, in_channels, out_channels_conv1, out_channels_conv2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)):
        super().__init__()

        self.network = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels_conv1, kernel_size=kernel_size, stride=stride, padding=padding),
            # Second BN: full pre-activation normalises before *every* ReLU.
            nn.BatchNorm2d(out_channels_conv1),
            nn.ReLU(),
            nn.Conv2d(out_channels_conv1, out_channels_conv2, kernel_size=kernel_size, stride=stride, padding=padding),
        )

    def forward(self, input):
        # Run the main path explicitly; calling self(input) re-enters forward() and recurses forever.
        return self.network(input) + input




class ResNetBlockChangeDepth(ResNetBlock):
    """ResNet block whose shortcut is a convolution, so input and output channel counts may differ.

    The shortcut conv maps in_channels -> out_channels_conv2. Its output must match
    the main path's output shape for the addition in forward() to work.
    """
    def __init__(self, in_channels, out_channels_conv1, out_channels_conv2, res_conv_kernel_size=1, res_conv_stride=1, res_conv_padding=0, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)):
        # Forward this block's own main-path settings. Previously `self` was passed a
        # second time (shifting every argument) and kernel/stride/padding were hard-coded.
        super().__init__(in_channels, out_channels_conv1, out_channels_conv2, kernel_size=kernel_size, stride=stride, padding=padding)

        # Projection shortcut so the residual can be added despite the channel change.
        self.conv = nn.Conv2d(in_channels, out_channels_conv2, kernel_size=res_conv_kernel_size, stride=res_conv_stride, padding=res_conv_padding)

    def forward(self, input):
        # Use the inherited main path directly; calling self(input) would recurse forever.
        return self.network(input) + self.conv(input)



class Wrapper(nn.Module):
    """Adapt an arbitrary callable so it can be used as a layer (e.g. inside nn.Sequential)."""

    def __init__(self, func):
        super(Wrapper, self).__init__()
        self.func = func

    def forward(self, x):
        wrapped = self.func
        return wrapped(x)


In [101]:
class ResNet(ImageClassifierNetwork):
    """ResNet-style image classifier.

    The layer shapes (and the 256*4*4 flatten size) assume 64x64 input images.

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()

        # Channel sizes follow the shape comments below. Previously the blocks were
        # constructed without arguments, which raises TypeError.
        self.resnet_block1 = ResNetBlock(128, 128, 128)
        self.resnet_block2 = ResNetBlockChangeDepth(128, 256, 256)
        self.resnet_block3 = ResNetBlockChangeDepth(256, 256, 256)

        self.network = nn.Sequential(

            # Prep 3x64x64
            nn.Conv2d(in_channels, 64, kernel_size=5, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 64x32x32

            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),  # 128x32x32

            # The blocks are nn.Modules, so they go into nn.Sequential directly (no Wrapper needed).
            self.resnet_block1,  # 128x32x32
            nn.MaxPool2d(2, 2),  # 128x16x16

            self.resnet_block2,  # 256x16x16
            nn.MaxPool2d(2, 2),  # 256x8x8

            self.resnet_block3,  # 256x8x8
            # The pre-activation blocks end with conv + residual add (no activation), so activate here.
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 256x4x4

            nn.Flatten(),
            # Fully connected layer
            nn.Linear(256 * 4 * 4, 2048),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(2048, num_classes),
        )

    def forward(self, input):
        """Return class logits for a batch of images."""
        # This must be a method (it was dedented to module level) and must call the
        # Sequential: calling self(input) re-enters forward() and recurses forever.
        return self.network(input)


