In [1]:
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched
import torch.utils.data as data
import numpy as np
import pickle

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
transform_raw_2 = transforms.Compose(
    [
        transforms.Resize(size=[256, 256]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [4]:
# Build the dataset from the "DS_SPECIES_CLASSIFICATION" folder; transform_raw_2
# is applied to every image when it is loaded.
dataset = torchvision.datasets.ImageFolder("DS_SPECIES_CLASSIFICATION", transform=transform_raw_2)
# Number of samples found (shown as the cell output).
len(dataset)

4361

In [5]:
# Mapping from class name to integer label, taken from the ImageFolder dataset.
species_dict = dataset.class_to_idx

# Persist the mapping to 'species_dict.pkl' so it can be reloaded later.
with open('species_dict.pkl', 'wb') as file:
    pickle.dump(species_dict, file)

# Display the mapping as the cell output.
species_dict

{'Borowik Królewski JAD': 0,
 'Borowik szatański NJAD': 1,
 'Gołąbek fiołkowoniebiesk JAD': 2,
 'Gołąbek odbielony NJAD': 3,
 'Kolczak rdzawoczerwony NJAD': 4,
 'Lejkówka fałdowana JAD': 5,
 'Lejkówka żebrowana JAD': 6,
 'Maślak JAD': 7,
 'Muchomor Cesarski JAD': 8,
 'Muchomor Cytrynowy JAD': 9,
 'Muchomor Panterowy NAJD': 10,
 'Pieczarka JAD': 11,
 'Smardz jadalny JAD': 12,
 'Sromotnik bezwstydny NJAD': 13,
 'Strzępiak wiosenny NJAD': 14,
 'Wilgotnica JAD': 15,
 'Wilgotnica szerokoblaszkowa JAD': 16,
 'Zasłonak olszowy NJAD': 17,
 'Zasłoniak NW': 18,
 'Łysiczka oliwkowa NJAD': 19}

In [6]:
# classes_elem_cnt = {}
# for i in range(len(dataset)):

#     x = dataset[i][1]

#     if (x in classes_elem_cnt.keys()):
#         classes_elem_cnt[x] += 1

#     else:
#         classes_elem_cnt[x] = 1

#     # if (i%10==0):
# print(classes_elem_cnt)

# with open('classes_member_cnt.pkl', 'wb') as file:
#     pickle.dump(classes_elem_cnt, file)


In [7]:
# Class weights for the loss function.
# Load the cached per-class sample counts (indexed below by class index 0..N-1).
# NOTE(review): pickle.load can execute arbitrary code from the file; only load a cache you created yourself.
with open('classes_member_cnt.pkl', 'rb') as file:
    classes_elem_cnt = pickle.load(file)

# Size of the most populated class; it becomes the reference for the weights.
max_count = max(classes_elem_cnt.values())

# Compute the weight of each class as max_count / class_count, so the largest
# class gets weight 1.0 and rarer classes get proportionally larger weights.
weights = torch.tensor([max_count / classes_elem_cnt[i] for i in range(len(classes_elem_cnt))])


In [8]:
# Seed PyTorch's global random number generator with a fixed value.
seed = 123
torch.manual_seed(seed)

<torch._C.Generator at 0x2705940d850>

In [9]:
# Random split: 4000 samples for training, every remaining sample for validation.
train_size = 4000
val_size = len(dataset) - train_size

train_ds, val_ds = data.random_split(dataset=dataset, lengths=[train_size, val_size])

In [10]:
batch_size = 50

# Only the training loader shuffles; both loaders use one worker process and
# pinned host memory.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=batch_size,
    num_workers=1,
    pin_memory=True,
)

In [11]:
class RestPercepton_block(nn.Module):
    """
    Inception-style block with a residual connection.

    The input is fed to four parallel branches (1x1 conv, 1x1 -> 3x3 conv,
    1x1 -> 5x5 conv, 3x3 max pool). Every branch ends with BatchNorm2d followed
    by ReLU. The branch outputs are concatenated along the channel dimension and
    the input (projected by a 1x1 conv when the channel counts differ) is added
    to the concatenation.

    in_channels - number of channels entering the block
    out_channels - number of channels leaving the block (all four branches concatenated)
    conv_size_in - [c3, c5]: channels produced by the 1x1 reduction placed in front of the 3x3 and the 5x5 conv
    conv_size_out - [o1, o3, o5]: output channels of the 1x1, 3x3 and 5x5 branch
    stride, padding - stride and padding of the 1x1, 3x3 and 5x5 conv respectively;
        the pool branch and the residual projection always use stride 1, and the
        defaults keep every branch at the spatial size of the input
    change_depth_pool - False (default) leaves the pool branch at in_channels channels;
        an int adds a 1x1 conv after the pool that outputs that many channels

    Raises ValueError when out_channels differs from
    sum(conv_size_out) + (number of channels of the pool branch).
    """

    def __init__(self, in_channels, out_channels,
                conv_size_in:list, conv_size_out:list,
                stride=(1, 1, 1), padding=(0, 1, 2),
                change_depth_pool=False):

        # The pool branch contributes either its projected depth or, when no
        # projection is requested, the unchanged input depth.
        pool_channels = change_depth_pool if change_depth_pool else in_channels
        conv_channels = sum(conv_size_out)
        expected_out = conv_channels + pool_channels
        if out_channels != expected_out:
            raise ValueError(
                "out_channels must equal sum(conv_size_out) + pool branch channels "
                f"({conv_channels} + {pool_channels} = {expected_out}), got {out_channels}"
            )

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # NOTE: attribute names and the positions inside each nn.Sequential are
        # part of the state_dict keys, so they are kept stable for saved checkpoints.

        # conv 1x1
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, conv_size_out[0], kernel_size=1, stride=stride[0], padding=padding[0]),
            nn.BatchNorm2d(conv_size_out[0]),
            nn.ReLU()
        )

        # conv 3x3, preceded by a 1x1 conv that changes the depth to conv_size_in[0]
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, conv_size_in[0], kernel_size=1, padding=0, stride=1),
            nn.Conv2d(conv_size_in[0], conv_size_out[1], kernel_size=3, stride=stride[1], padding=padding[1]),
            nn.BatchNorm2d(conv_size_out[1]),
            nn.ReLU()
        )

        # conv 5x5, preceded by a 1x1 conv that changes the depth to conv_size_in[1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels, conv_size_in[1], kernel_size=1, padding=0, stride=1),
            nn.Conv2d(conv_size_in[1], conv_size_out[2], kernel_size=5, stride=stride[2], padding=padding[2]),
            nn.BatchNorm2d(conv_size_out[2]),
            nn.ReLU()
        )

        # max pool 3x3, optionally followed by a 1x1 conv that changes the depth
        if change_depth_pool:
            self.pool = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                nn.Conv2d(in_channels, change_depth_pool, kernel_size=1, padding=0, stride=1),
                nn.BatchNorm2d(change_depth_pool),
                nn.ReLU()
            )
        else:
            self.pool = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(self.in_channels),
                nn.ReLU()
            )

        # 1x1 projection of the residual path, only needed when the depth changes
        if in_channels != out_channels:
            self.RestConv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Return the channel-wise concatenation of the four branches plus the residual."""
        branches = [self.conv1(x), self.conv3(x), self.conv5(x), self.pool(x)]

        if self.in_channels != self.out_channels:
            residual = self.RestConv(x)
        else:
            residual = x

        return torch.cat(branches, dim=1) + residual

In [12]:
class ImageClassifierNetwork(nn.Module):
    """
    Base class that adds loss and metric helpers to an image classifier.

    Subclasses provide forward(). The class reads two module-level names of the
    notebook: `weights` (per-class weights for the cross-entropy loss) and
    `device` (where inputs and labels are moved before the forward pass).
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(weight=weights)

    def train_step(self, input, labels):
        """Run a forward pass on one batch and return the (differentiable) loss."""
        input = input.to(device=device)
        labels = labels.to(device=device)

        preds = self.forward(input)
        return self.loss(preds, labels)

    def val_step(self, input, labels):
        """
        Evaluate one batch without tracking gradients.

        Returns a dict with the detached "loss", the "accuracy" of the batch and
        its "batch_size" (used by val_epoch_end to weight the batch).
        """
        input = input.to(device=device)
        labels = labels.to(device=device)

        # No autograd graph is needed for validation.
        with torch.no_grad():
            preds = self.forward(input)
            loss = self.loss(preds, labels)
        accuracy = self._accuracy(preds, labels)

        return {"loss": loss.detach(), "accuracy": accuracy, "batch_size": len(labels)}

    def val_epoch_end(self, preformance_measurement_data):
        """
        Combine the per-batch dicts produced by val_step into (avg_loss, avg_accuracy).

        Each batch value is a mean over that batch, so batches are weighted by
        their "batch_size"; otherwise a short final batch would count as much as
        a full one. Entries without "batch_size" are given weight 1, which
        reproduces a plain mean. For the class-weighted loss the result is an
        approximation of the dataset-level value.
        """
        if len(preformance_measurement_data) == 0:
            # Nothing was evaluated; there is no meaningful average.
            return np.nan, np.nan

        batch_sizes = [x.get("batch_size", 1) for x in preformance_measurement_data]

        accuracy = [x["accuracy"].cpu().numpy() for x in preformance_measurement_data]
        avg_accuracy = np.average(accuracy, weights=batch_sizes)

        loss = [x["loss"].cpu().numpy() for x in preformance_measurement_data]
        avg_loss = np.average(loss, weights=batch_sizes)

        return avg_loss, avg_accuracy

    def _accuracy(self, preds, labels):
        """Return, as a 0-dim tensor, the fraction of rows whose argmax equals the label."""
        batch_size = len(preds)

        pred_indices = torch.argmax(preds, dim=1)
        correct = torch.sum(pred_indices == labels).item()
        return torch.tensor(correct / batch_size)


In [13]:
class RestGoogleNet_Clasificator(ImageClassifierNetwork):
    """
    GoogLeNet-like classifier built from two 5x5 convolutions followed by five
    RestPercepton_block modules.

    Input - in_channels x 256x256. Six 2x2 max pools halve the resolution down to
    4x4; the 4x4 average pool then leaves one value per channel, which gives the
    712 features the final linear layer expects. Other resolutions that do not
    end in a 4x4 map will not match the linear layer.

    in_channels - number of channels of the input images
    num_classes - number of output logits
    """
    def __init__(self, in_channels, num_classes):
        super().__init__()

        # Stem convolutions (spatial size preserved by padding=2).
        self.Conv1 = nn.Conv2d(in_channels, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.Conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2)

        # Inception blocks; each out_channels is sum(conv_size_out) + change_depth_pool.
        self.Intercepton1 = RestPercepton_block(in_channels=64, out_channels=168, conv_size_in=[32,32], conv_size_out=[24, 64, 64],change_depth_pool=16)
        self.Intercepton2 = RestPercepton_block(in_channels=168, out_channels=318, conv_size_in=[72,72], conv_size_out=[60, 118, 108],change_depth_pool= 32)
        self.Intercepton3 = RestPercepton_block(in_channels=318, out_channels=414, conv_size_in=[128,96], conv_size_out=[78, 154, 134],change_depth_pool=48)
        self.Intercepton4 = RestPercepton_block(in_channels=414, out_channels=540, conv_size_in=[140,124], conv_size_out=[112, 190, 174],change_depth_pool=64)
        self.Intercepton5 = RestPercepton_block(in_channels=540, out_channels=712, conv_size_in=[192,150], conv_size_out=[152, 256, 224],change_depth_pool=80)

        self.MaxPool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.AvgPool_4x4 = nn.AvgPool2d(kernel_size=4, stride=1)

        self.Dropout = nn.Dropout(0.5)
        self.Linear = nn.Linear(712, num_classes)

        # Batch norms for the two stem convolutions.
        self.BN1 = nn.BatchNorm2d(32)
        self.BN2 = nn.BatchNorm2d(64)

        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the class logits, shape (batch, num_classes), for a batch of images."""
        # Stem: conv -> BN -> ReLU -> 2x2 max pool, twice (256 -> 128 -> 64).
        x = self.MaxPool_2x2(self.activation(self.BN1(self.Conv1(x))))
        x = self.MaxPool_2x2(self.activation(self.BN2(self.Conv2(x))))

        # Blocks 1-4 are each followed by a 2x2 max pool (64 -> 32 -> 16 -> 8 -> 4).
        for block in (self.Intercepton1, self.Intercepton2, self.Intercepton3, self.Intercepton4):
            x = self.MaxPool_2x2(block(x))

        # Last block keeps 4x4; the average pool collapses it to 1x1.
        x = self.Intercepton5(x)
        x = self.AvgPool_4x4(x)

        # Flatten to (batch, features) for the classifier head.
        x = x.view(x.size(0), -1)

        x = self.Dropout(x)
        x = self.Linear(x)

        return x

In [14]:
# Number of output classes (species_dict above lists 20 classes, indices 0-19).
num_classes = 20
model = RestGoogleNet_Clasificator(in_channels=3, num_classes=num_classes)
# Move the model to the selected device; the returned module is shown as the cell output.
model.to(device)
# model.load_state_dict(torch.load("5_epch_googleNet_SPECIES_V1.pt"))

RestGoogleNet_Clasificator(
  (loss): CrossEntropyLoss()
  (Conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (Conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (Intercepton1): RestPercepton_block(
    (conv1): Sequential(
      (0): Conv2d(64, 24, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (conv3): Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()
    )
    (conv5): Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ReLU()


In [15]:
# Total number of scalar values across all parameter tensors of the model.
params_sum = sum(params.numel() for params in model.parameters())
params_sum

4432058

In [16]:
epoch = 5

# Adam with weight decay; every scheduler step multiplies the learning rate by 0.99.
optimizer = optim.Adam(params=model.parameters(), lr=1.5e-3, weight_decay=1e-4)
lr_scheduler = lr_sched.ExponentialLR(optimizer=optimizer, gamma=0.99)

In [17]:
def eval(val_loader, model):
    """
    Evaluate `model` on every batch of `val_loader`.

    The model is switched to evaluation mode (so Dropout is disabled and
    BatchNorm uses its running statistics) and gradients are not tracked while
    the batches are processed. The mode the model was in before the call is
    restored afterwards, even if a batch raises.

    Returns whatever model.val_epoch_end returns for the collected
    model.val_step results (for ImageClassifierNetwork: (loss_val, acc_val)).

    Note: this function shadows the built-in eval() inside the notebook.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            history = [
                model.val_step(batch_input, batch_labels)
                for batch_input, batch_labels in val_loader
            ]
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    loss_val, acc_val = model.val_epoch_end(history)

    return loss_val, acc_val

In [18]:
# Re-seed the global PyTorch RNG with the seed it was initialised with.
torch.manual_seed(seed=torch.initial_seed())

# Training-time augmentation: each call applies exactly one of the listed
# transforms, picked at random:
#   - horizontal flip with probability 0.5
#   - rotation by an angle drawn from the range 20 to 60 degrees
#   - perspective warp with distortion scale 0.35
# NOTE(review): train() calls this on a whole batch tensor, so presumably every
# image of a batch receives the same augmentation - confirm this is intended.
transform = transforms.RandomChoice(
    [
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(degrees=(20, 60)),
    transforms.RandomPerspective(distortion_scale=.35),
    ]
)

In [19]:
def train(train_dataloader, epoch):
    """
    Train the notebook-level `model` for `epoch` epochs.

    Uses the module-level `model`, `optimizer`, `lr_scheduler`, `transform`,
    `val_loader` and the notebook's `eval` function. After every epoch the model
    is validated in evaluation mode without gradient tracking, and a summary
    line is printed. The model is left in evaluation mode when training ends.

    train_dataloader - iterable of (input, label) batches
    epoch - number of passes over train_dataloader

    Returns (epoch_loss_T, epoch_loss_V, epoch_acc_V): per-epoch mean training
    loss, validation loss and validation accuracy.
    """
    epoch_loss_T = []
    epoch_loss_V = []
    epoch_acc_V = []

    for ep in range(epoch):
        # Validation below switches to eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train(mode=True)
        batch_loss_T = []

        for input, label in train_dataloader:

            transformed_input = transform(input)
            loss = model.train_step(transformed_input, label)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            # NOTE(review): the scheduler is stepped once per batch, so the
            # learning rate decays by gamma per batch rather than per epoch -
            # confirm this is intended.
            lr_scheduler.step()

            batch_loss_T.append(loss.item())

        epoch_loss_T.append(np.mean(batch_loss_T))

        # Validate with Dropout disabled and BatchNorm using running statistics,
        # and without building an autograd graph.
        model.eval()
        with torch.no_grad():
            loss_val, acc_val = eval(val_loader, model)
        epoch_loss_V.append(loss_val)
        epoch_acc_V.append(acc_val)

        print(f'Epoch {ep+1}; TLoss: {epoch_loss_T[ep]}; Vloss: {loss_val}; Acc: {acc_val}')

    model.eval()
    return epoch_loss_T, epoch_loss_V, epoch_acc_V

In [20]:
# loss_T, loss_V, acc_V = train(train_loader, epoch)

In [21]:
# Save the model's state_dict; the file name is prefixed with the configured epoch count.
# NOTE(review): the train(...) call in the previous cell is commented out, so as
# written this stores whatever weights the model currently holds.
name = f'{epoch}_epch_googleNet_SPECIES_V1.pt'
torch.save(model.state_dict(), name)

In [22]:
# Reload the class-name -> index mapping from 'species_dict.pkl'.
# NOTE(review): pickle.load can execute arbitrary code; only load files you created yourself.
with open('species_dict.pkl', 'rb') as file:
    species_dict = pickle.load(file)

In [23]:
def get_pred(image):
    """
    Predict the species shown on a single image.

    The image is preprocessed with the notebook-level `transform_raw_2`, given a
    batch dimension, moved to the device of the notebook-level `model` and
    classified in evaluation mode without gradient tracking. The model's
    previous train/eval mode is restored afterwards.

    image - input accepted by transform_raw_2

    Returns (predicted_species, probability): the list of names in
    `species_dict` whose index equals the predicted class, and the softmax
    probability of that class as a 0-dim tensor.
    """
    softmax = nn.Softmax(dim=1)

    # The model works on batches, so add a leading batch dimension of size 1
    # and put the input on the same device as the model's parameters.
    batch = transform_raw_2(image).unsqueeze(0)
    batch = batch.to(next(model.parameters()).device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred = softmax(model(batch))
    finally:
        model.train(was_training)

    # Plain int so it can be compared with the dictionary values and used as an index.
    class_idx = int(torch.argmax(pred, dim=1).item())

    predicted_species = [key for key, value in species_dict.items() if value == class_idx]
    probability = pred[0, class_idx]

    return predicted_species, probability


SyntaxError: invalid syntax (2937492945.py, line 2)