* # Обработка текстов

Предполагаются зафиксированными следующие гиперпараметры модели:
- размер батча: 64;
- длина обрезанного токенизированного предложения: 200;
- размерность эмбеддинга одного токена: 32;
- размерность скрытого состояния (hidden) рекуррентной нейронной сети: 128;
- размер выхода модели (после линейного слоя): 10.

In [1]:
import torch
class RNNClassifier(torch.nn.Module):
    """
    Text classifier: token embedding -> recurrent layer -> linear head.

    The linear head is applied to the recurrent output taken at the last
    non-padding position of every sequence. Inputs are indexed as
    (seq_len, batch): the recurrent layer is built without ``batch_first``
    (TODO confirm when a custom ``rec_layer`` / kwargs are passed).
    """

    def __init__(
        self, embedding_dim, hidden_dim, output_size, vocab,
        rec_layer=torch.nn.LSTM, dropout=None, **kwargs
    ):
        super().__init__()

        self.dropout = dropout

        self.vocab = vocab
        self.hidden_dim = hidden_dim
        self.output_size = output_size
        self.embedding_dim = embedding_dim

        # Token id 0 is the padding index: its embedding starts at zero and gets no gradient.
        self.word_embeddings = torch.nn.Embedding(
            num_embeddings=len(self.vocab.stoi),
            embedding_dim=self.embedding_dim,
            padding_idx=0,
        )
        if dropout is not None:
            self.rnn = rec_layer(self.embedding_dim, self.hidden_dim, dropout=self.dropout, **kwargs)
        else:
            self.rnn = rec_layer(self.embedding_dim, self.hidden_dim, **kwargs)
        self.Linear = torch.nn.Linear(self.hidden_dim, output_size)

    def forward(self, tokens, tokens_lens):
        """
        :param torch.tensor(dtype=torch.long) tokens: token ids, shape (seq_len, batch).
        :param torch.tensor(dtype=torch.long) tokens_lens: number of non-padding tokens
            for each sequence in the batch, shape (batch,).
        :return torch.tensor: float logits of shape (batch, output_size).
        """
        embedded = self.word_embeddings(tokens)        # (seq_len, batch, embedding_dim)

        # The recurrent layer returns (output, state); output holds the hidden state at
        # every time step, so its last dimension is hidden_dim (e.g. 128), not a linear-layer size.
        rnn_out = self.rnn(embedded)[0]                # (seq_len, batch, hidden_dim)

        # Pick, for each sequence, the output at its last real token. Build the indices on
        # the output's device so this also works when the model runs on a GPU.
        batch_idx = torch.arange(rnn_out.shape[1], device=rnn_out.device)
        last_idx = tokens_lens.to(rnn_out.device) - 1
        last_states = rnn_out[last_idx, batch_idx, :]  # (batch, hidden_dim)

        return self.Linear(last_states)                # (batch, output_size)

In [2]:
def train_epoch(dataloader, model, loss_fn, optimizer, device):
    """
    Run a single optimisation pass over ``dataloader``.

    Every batch is a mapping with the keys 'tokens', 'tokens_lens' and
    'ratings'; those tensors are moved to ``device``, the model is called as
    ``model(tokens, tokens_lens)`` and one optimiser step is taken per batch.
    Returns nothing; the model is left in training mode.
    """
    model.train()
    for batch in dataloader:
        tokens = batch['tokens'].to(device)
        lengths = batch['tokens_lens'].to(device)
        targets = batch['ratings'].to(device)

        optimizer.zero_grad()
        predictions = model(tokens, lengths).to(device)
        loss_fn(predictions, targets).backward()
        optimizer.step()
    
def evaluate(dataloader, model, loss_fn, device):
    """
    Compute the mean loss and accuracy of ``model`` over ``dataloader``.

    Every batch is a mapping with keys 'tokens', 'tokens_lens', 'ratings'.

    :return: (mean_loss, accuracy) as Python floats. ``mean_loss`` is the
        per-sample mean assuming ``loss_fn`` averages over the batch
        (reduction='mean', the PyTorch default) -- TODO confirm for custom losses.
        ``accuracy`` is the fraction of samples whose argmax prediction equals
        the rating.
    """
    model.eval()

    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for data in dataloader:
            tokens = data['tokens'].to(device)
            tokens_lens = data['tokens_lens'].to(device)
            ratings = data['ratings'].to(device)

            model_out = model(tokens, tokens_lens)
            batch_size = ratings.shape[0]

            # loss_fn returns a batch mean; weight it by the batch size so that dividing
            # by the dataset size below yields a true per-sample mean (the last batch may
            # be smaller than the others).
            total_loss += float(loss_fn(model_out, ratings)) * batch_size
            total_correct += int((torch.argmax(model_out, dim=1) == ratings).sum())

    num_samples = len(dataloader.dataset)
    return total_loss / num_samples, total_correct / num_samples

def train(train_loader, test_loader, model, loss_fn, optimizer, device, num_epochs):
    """
    Train for ``num_epochs`` epochs, evaluating on both loaders after each one.

    After every epoch the train and test metrics (from ``evaluate``) are printed
    and recorded.

    :return: (train_losses, train_accuracies, test_losses, test_accuracies),
        each a list with one entry per epoch.

    NOTE: the image-segmentation section later in this notebook defines another
    ``train(model)`` that replaces this one once that cell runs.
    """
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    for epoch in range(num_epochs):
        train_epoch(train_loader, model, loss_fn, optimizer, device)

        train_loss, train_acc = evaluate(train_loader, model, loss_fn, device)
        test_loss, test_acc = evaluate(test_loader, model, loss_fn, device)

        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        message = 'Epoch: {0:d}/{1:d}. Loss (Train/Test): {2:.3f}/{3:.3f}. Accuracy (Train/Test): {4:.3f}/{5:.3f}'
        print(message.format(epoch + 1, num_epochs, train_loss, test_loss, train_acc, test_acc))

    return train_losses, train_accuracies, test_losses, test_accuracies

# Реализация дропаута по статье Гала и Гахрамани (Variational Dropout)

In [3]:
def init_h0_c0(num_objects, hidden_size, some_existing_tensor):
    """
    Create zero initial hidden and cell states for an LSTM.

    Both tensors have shape (num_objects, hidden_size) and are allocated via
    ``some_existing_tensor.new_zeros`` so they inherit its dtype and device.

    :return: tuple (h0, c0) of two independent zero tensors.
    """
    state_shape = (num_objects, hidden_size)
    h0 = some_existing_tensor.new_zeros(size=state_shape)
    c0 = some_existing_tensor.new_zeros(size=state_shape)
    return h0, c0

def gen_dropout_mask(input_size, hidden_size, is_training, p, some_existing_tensor):
    """
    Build per-feature dropout masks for the hidden state and for the input.

    :param input_size: length of the input mask.
    :param hidden_size: length of the hidden mask.
    :param is_training: if True, sample Bernoulli(1 - p) masks; otherwise return
        deterministic masks filled with (1 - p) (non-inverted dropout scaling).
    :param p: drop probability, or None to disable dropout (all-ones masks).
    :param some_existing_tensor: template tensor that supplies dtype and device.
    :return: tuple (hidden_mask, input_mask) of 1-D tensors of lengths
        hidden_size and input_size, in that order.
    """
    hidden_ones = some_existing_tensor.new_ones(hidden_size)
    input_ones = some_existing_tensor.new_ones(input_size)

    if p is None:
        return hidden_ones, input_ones

    keep_prob = 1 - p
    if not is_training:
        return hidden_ones * keep_prob, input_ones * keep_prob

    # Sample the hidden mask first, then the input mask (fixed RNG consumption order).
    hidden_mask = torch.bernoulli(hidden_ones * keep_prob)
    input_mask = torch.bernoulli(input_ones * keep_prob)
    return hidden_mask, input_mask

In [4]:
class RNNLayer(torch.nn.Module):
    """
    Single-layer LSTM unrolled over time with an ``LSTMCell``.

    One hidden mask and one input mask are sampled per forward call (see
    ``gen_dropout_mask``) and reused at every time step. Both masks are 1-D, so
    they are shared across the batch.
    """

    def __init__(self, input_size, hidden_size, dropout=None):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.rnn_cell = torch.nn.LSTMCell(self.input_size, self.hidden_size)

    def forward(self, x):
        """
        :param x: tensor of shape (seq_len, batch, input_size).
        :return: (output, (h, c)). ``output`` stacks the hidden state of every
            step, shape (seq_len, batch, hidden_size); ``h`` and ``c`` are the
            final states, shape (batch, hidden_size).
        """
        h, c = init_h0_c0(num_objects=x.shape[1],
                          hidden_size=self.hidden_size,
                          some_existing_tensor=x)

        hidden_mask, input_mask = gen_dropout_mask(input_size=self.input_size,
                                                   hidden_size=self.hidden_size,
                                                   is_training=self.training,
                                                   p=self.dropout,
                                                   some_existing_tensor=x)
        # hidden_mask: (hidden_size,), input_mask: (input_size,) -- broadcast over the batch.

        step_outputs = []
        for step_input in x:  # iterates over the time dimension
            h, c = self.rnn_cell(step_input * input_mask, (h * hidden_mask, c))
            step_outputs.append(h)

        return torch.stack(step_outputs, dim=0), (h, c)

# Эффективная реализация дропаута по статье Гала и Гахрамани

In [5]:
class FastRNNLayer(torch.nn.Module):
    """
    ``torch.nn.LSTM`` wrapper that applies dropout to its recurrent and input weights.

    The LSTM's ``weight_hh_l*`` / ``weight_ih_l*`` parameters are moved onto this
    module as ``<name>_raw``. On every forward pass, masked copies of them are
    written back onto the wrapped LSTM before it runs.
    """

    def __init__(self, input_size, hidden_size, dropout=None, layers_dropout=0.0, num_layers=1):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.layers_dropout = layers_dropout
        self.module = torch.nn.LSTM(input_size, hidden_size, dropout=layers_dropout, num_layers=num_layers)

        # Order: weight_hh_l0, weight_ih_l0, weight_hh_l1, weight_ih_l1, ...
        self.layer_names = [
            f'{kind}_l{layer_idx}'
            for layer_idx in range(self.num_layers)
            for kind in ('weight_hh', 'weight_ih')
        ]

        for name in self.layer_names:
            original_param = getattr(self.module, name)
            # Detach the parameter from the LSTM and re-register its data here under
            # a new name, so the optimiser trains the raw (unmasked) weights.
            delattr(self.module, name)
            self.register_parameter(f'{name}_raw', torch.nn.Parameter(original_param.data))

    def _setweights(self, x):
        """
        Write masked copies of the raw weights back onto the wrapped LSTM.

        Each weight matrix gets freshly generated masks. The mask is broadcast
        over the weight's last (input-feature) dimension: the hidden mask when
        that dimension equals hidden_size, the input mask otherwise.
        """
        for name in self.layer_names:
            raw_weight = getattr(self, f'{name}_raw')

            hidden_mask, input_mask = gen_dropout_mask(input_size=self.input_size,
                                                       hidden_size=self.hidden_size,
                                                       is_training=self.training,
                                                       p=self.dropout, some_existing_tensor=x)

            uses_hidden = raw_weight.shape[1] == hidden_mask.shape[0]
            mask = hidden_mask if uses_hidden else input_mask

            setattr(self.module, name, raw_weight * mask)

    def forward(self, x, h_c=None):
        """
        :param x: tensor containing the features of the input sequence.
        :param Optional[Tuple[torch.tensor, torch.tensor]] h_c: initial hidden state and initial cell state
        :return: whatever the wrapped ``torch.nn.LSTM`` returns, i.e. (output, (h_n, c_n)).
        """
        with warnings.catch_warnings():
            # The re-assigned weights are not flattened into one contiguous buffer,
            # which makes the LSTM emit a warning; silence it.
            warnings.simplefilter("ignore")

            self._setweights(x)
            if h_c is None:
                return self.module.forward(x)
            return self.module.forward(x, h_c)

    def reset(self):
        """Forward a ``reset`` call to the wrapped module if it defines one."""
        if hasattr(self.module, 'reset'):
            self.module.reset()

# Реализация дропаута по статье Семениуты и др.

In [6]:
class HandmadeLSTM(torch.nn.Module):
    """
    Hand-written LSTM with recurrent dropout in the style of Semeniuta et al.

    The hidden mask multiplies the candidate update ``i * g`` before it is added
    to the cell state; the input mask multiplies the input at every step. Both
    masks are sampled once per forward call.

    Gate layout in the 4*hidden projection is (i, o, f, g). This differs from
    ``torch.nn.LSTM``'s (i, f, g, o), so weights are not interchangeable.
    """

    def __init__(self, input_size, hidden_size, dropout=None):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.input_weights = torch.nn.Linear(input_size, 4 * hidden_size)
        self.hidden_weights = torch.nn.Linear(hidden_size, 4 * hidden_size)

        self.reset_params()

    def reset_params(self):
        """
        Initialise every parameter uniformly in [-1/sqrt(hidden), 1/sqrt(hidden)],
        as torch.nn.LSTM does. Called from ``__init__``.
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM
        """
        bound = 1.0 / np.sqrt(self.hidden_size)
        for param in self.parameters():
            torch.nn.init.uniform_(param, -bound, bound)

    def forward(self, x):
        """
        :param x: tensor of shape (seq_len, batch, input_size).
        :return: (output, (h, c)); ``output`` has shape (seq_len, batch, hidden_size).
        """
        hid = self.hidden_size
        h, c = init_h0_c0(num_objects=x.shape[1],
                          hidden_size=hid,
                          some_existing_tensor=x)

        hidden_mask, input_mask = gen_dropout_mask(input_size=self.input_size,
                                                   hidden_size=hid,
                                                   is_training=self.training,
                                                   p=self.dropout,
                                                   some_existing_tensor=x)

        step_outputs = []
        for step_input in x:  # iterates over the time dimension
            gates = self.hidden_weights(h) + self.input_weights(step_input * input_mask)  # (batch, 4*hid)
            sigmoid_part = torch.sigmoid(gates[:, :3 * hid])
            in_gate, out_gate, forget_gate = sigmoid_part.chunk(3, dim=1)
            candidate = torch.tanh(gates[:, 3 * hid:])

            c = forget_gate * c + in_gate * candidate * hidden_mask
            h = out_gate * torch.tanh(c)

            step_outputs.append(h)
        return torch.stack(step_outputs, dim=0), (h, c)

# Реализация Zoneout

In [7]:
class Zoneout(torch.nn.Module):
    """
    Hand-written LSTM with zoneout on the hidden state.

    At every step each hidden unit either takes its new value (mask 1) or keeps
    its previous one (mask 0). The mask is resampled per time step. At eval
    time the mask is (1 - p), giving a convex blend of new and previous values.
    The input mask is sampled once per forward call. The cell state is not
    zoned out.

    Gate layout in the 4*hidden projection is (i, o, f, g), which differs from
    ``torch.nn.LSTM``'s (i, f, g, o).
    """

    def __init__(self, input_size, hidden_size, dropout=None):
        super().__init__()

        self.dropout = dropout
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.input_weights = torch.nn.Linear(input_size, 4 * hidden_size)
        self.hidden_weights = torch.nn.Linear(hidden_size, 4 * hidden_size)

        self.reset_params()

    def reset_params(self):
        """
        Initialise every parameter uniformly in [-1/sqrt(hidden), 1/sqrt(hidden)],
        as torch.nn.LSTM does. Called from ``__init__``.
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM
        """
        bound = 1.0 / np.sqrt(self.hidden_size)
        for param in self.parameters():
            torch.nn.init.uniform_(param, -bound, bound)

    def forward(self, x):
        """
        :param x: tensor of shape (seq_len, batch, input_size).
        :return: (output, (h, c)); ``output`` has shape (seq_len, batch, hidden_size).
        """
        hid = self.hidden_size
        h, c = init_h0_c0(num_objects=x.shape[1],
                          hidden_size=hid,
                          some_existing_tensor=x)

        # Only the input mask (second element) is kept from this call.
        input_mask = gen_dropout_mask(input_size=self.input_size,
                                      hidden_size=hid,
                                      is_training=self.training,
                                      p=self.dropout,
                                      some_existing_tensor=x)[1]

        step_outputs = []
        for step_input in x:  # iterates over the time dimension
            # Fresh zoneout mask (first element) for this step.
            zone_mask = gen_dropout_mask(input_size=self.input_size,
                                         hidden_size=hid,
                                         is_training=self.training,
                                         p=self.dropout,
                                         some_existing_tensor=x)[0]

            gates = self.hidden_weights(h) + self.input_weights(step_input * input_mask)  # (batch, 4*hid)
            sigmoid_part = torch.sigmoid(gates[:, :3 * hid])
            in_gate, out_gate, forget_gate = sigmoid_part.chunk(3, dim=1)
            candidate = torch.tanh(gates[:, 3 * hid:])

            c = forget_gate * c + in_gate * candidate
            h = zone_mask * (out_gate * torch.tanh(c)) + (1 - zone_mask) * h

            step_outputs.append(h)
        return torch.stack(step_outputs, dim=0), (h, c)

* # Сегментация изображений

# Модуль аугментации

In [8]:
class Flipping:
    """Horizontally flip (along the last axis) an image and its mask with probability ``p``."""

    def __init__(self, p=1):
        self.p = p

    def __call__(self, img, mask=None):
        """
        :param img: image tensor, e.g. (C, H, W); flipped along its last axis.
        :param mask: optional mask tensor flipped the same way; may be None.
        :return: (img, mask) converted to float64; mask stays None if not given.
        """
        if np.random.rand() < self.p:
            # torch.flip reverses the width axis in O(C*H*W). The former anti-diagonal
            # matrix product cost O(C*H*W^2), only worked on CPU tensors, and spread any
            # NaN/inf across the whole row (0 * inf = nan).
            img = torch.flip(img, dims=[-1])
            if mask is not None:
                mask = torch.flip(mask, dims=[-1])
        return img.double(), None if mask is None else mask.double()
    
    
class Cutting:
    """
    With probability ``p``, keep only a rectangular window of the image and mask
    and zero out everything else.

    The window covers rows between ``a`` and ``b`` and columns between ``c`` and
    ``d`` (inclusive, in either order). When ``size`` is not given, ``a``/``c`` are
    drawn from the first quarter and ``b``/``d`` from the last quarter of the
    height/width.
    """

    def __init__(self, p=1):
        self.p = p

    def __call__(self, img, mask=None, size=None):
        if not (np.random.rand() < self.p):
            return img.double(), mask.double()

        height, width = img.shape[1], img.shape[2]
        if size is not None:
            a, b, c, d = size
        else:
            a = np.random.randint(0, height // 4)
            b = np.random.randint(height * 3 // 4, height)
            c = np.random.randint(0, width // 4)
            d = np.random.randint(width * 3 // 4, width)

        rows = np.arange(height, dtype=float)
        cols = np.arange(width, dtype=float)
        row_inside = (rows - a) * (rows - b) <= 0
        col_inside = (cols - c) * (cols - d) <= 0
        keep = np.outer(row_inside, col_inside).astype(float)  # (H, W) of 0.0 / 1.0
        return img * keep, mask * keep
    
    
class Brighting:
    """With probability ``p``, add a brightness offset to the image and clip it to [0, 1]."""

    def __init__(self, p=1):
        self.p = p

    def __call__(self, img, mask=None, delta=None):
        """
        :param img: image tensor.
        :param mask: optional mask tensor (returned unchanged apart from dtype).
        :param delta: brightness offset; drawn from N(0, 1) when None.
        :return: (img, mask) as float64 tensors; mask stays None if not given.
            The input tensor is never modified.
        """
        if np.random.rand() < self.p:
            if delta is None:
                delta = np.random.randn()
            # Out-of-place: the original in-place `img += delta` silently changed the caller's tensor.
            img = torch.clamp(img + delta, 0, 1)
        # Same dtype on both paths (previously only the no-op path converted to float64).
        return img.double(), None if mask is None else mask.double()
    
    
class Screening:
    """
    With probability ``p``, replace background pixels of the image with the
    corresponding pixels of ``screen``.

    A pixel is background when the sum of the first three mask channels is 0.
    """

    def __init__(self, p=1):
        self.p = p

    def __call__(self, img, mask=None, screen=None):
        """
        :param img: image tensor with at least 3 channels, (C, H, W).
        :param mask: mask tensor with at least 3 channels, (C, H, W).
        :param screen: tensor the background pixels are copied from; required
            when the transform fires.
        :return: (img, mask) as float64 tensors; the input ``img`` is not modified.
        :raises ValueError: if the transform fires and ``screen`` is None.
        """
        if np.random.rand() < self.p:
            if screen is None:
                raise ValueError('Screening requires a `screen` tensor')
            screen = screen.type(img.dtype)
            # A plain bool tensor index; the previous `[bool_tensor]` list relied on the
            # legacy "treat sequence as tuple" indexing rule.
            background = (mask[0] + mask[1] + mask[2]) == 0
            img = img.clone()  # do not overwrite the caller's tensor
            for channel in range(3):
                img[channel][background] = screen[channel][background]
        return img.double(), mask.double()
    
    
def sequential_transforms(img, mask, propabilities=(1, 1, 1, 1), params=(None, None, None)):
    """
    Full augmentation pipeline for an (image, mask) pair.

    Steps: ToTensor + ImageNet normalisation of the image, ToTensor of the mask,
    then Flipping -> Brighting -> Cutting -> Screening.
    ``propabilities[0..3]`` are the probabilities of those four steps.
    ``params = (delta, cut_size, screen)`` are forwarded to Brighting, Cutting
    and Screening. When ``screen`` is None, a copy of the image taken right
    after flipping is used as the screen.

    NOTE(review): Brighting clips to [0, 1], but it runs after Normalize, whose
    output is not in [0, 1]; negative normalised values get zeroed. Confirm
    whether brightening should happen before normalisation.
    """
    to_tensor = torchvision.transforms.ToTensor()
    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    img = normalize(to_tensor(img))
    mask = to_tensor(mask)

    img, mask = Flipping(propabilities[0])(img, mask)
    screen = copy.deepcopy(img) if params[2] is None else params[2]
    img, mask = Brighting(propabilities[1])(img, mask, params[0])
    img, mask = Cutting(propabilities[2])(img, mask, params[1])
    img, mask = Screening(propabilities[3])(img, mask, screen=screen)
    return img, mask

# Функция потерь

In [9]:
class DiceLoss(torch.nn.Module):
    """Per-sample soft Dice loss: 1 - 2*|P*T| / (|P| + |T| + eps)."""

    _REDUCTIONS = (None, 'none', 'mean', 'sum')

    def __init__(self, eps=1e-7, reduction=None, with_logits=True):
        """
        Arguments
        ---------
        eps : float
            Added once to every per-sample denominator to avoid division by zero.
        reduction : Optional[str] (None, 'none', 'mean' or 'sum')
            specifies the reduction to apply to the per-sample losses:

            None / 'none': no reduction, one loss per sample is returned
            'mean': the per-sample losses are averaged over the batch
            'sum':  the per-sample losses are summed
        with_logits : bool
            If True, apply a sigmoid to the inputs first

        Raises
        ------
        ValueError
            If ``reduction`` is not one of the values above.
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f'Unsupported reduction: {reduction!r}')
        self.eps = eps
        self.reduction = reduction
        self.with_logits = with_logits

    def forward(self, logits, true_labels):
        """
        :param logits: predictions of shape (N, ...), raw logits if ``with_logits``.
        :param true_labels: targets of the same shape as ``logits``.
        :return: per-sample losses of shape (N,) or a scalar, depending on ``reduction``.
        """
        true_labels = true_labels.float()

        probs = torch.sigmoid(logits) if self.with_logits else logits
        if probs.dim() < 2:
            raise ValueError('DiceLoss expects inputs of shape (N, ...) with at least 2 dimensions')

        # Reduce over every non-batch dimension: (N, H, W) as before, but also (N, C, H, W) or (N, L).
        dims = tuple(range(1, probs.dim()))
        intersection = torch.sum(probs * true_labels, dim=dims)
        # eps is added once per sample, so its effect does not grow with the image size.
        denominator = torch.sum(probs + true_labels, dim=dims) + self.eps
        losses = 1 - 2 * intersection / denominator

        if self.reduction == 'sum':
            return torch.sum(losses)
        if self.reduction == 'mean':
            return torch.mean(losses)
        return losses

In [10]:
class Combination_Loss(torch.nn.Module):
    """
    Convex combination of BCE-with-logits and Dice loss.

    loss = alpha * BCEWithLogits(net_out, target) + (1 - alpha) * Dice(net_out, target),
    where Dice uses reduction='mean' and applies its own sigmoid.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.first_criterion = torch.nn.BCEWithLogitsLoss()
        self.second_criterion = DiceLoss(reduction='mean')

    def forward(self, net_out, target):
        bce_term = self.first_criterion(net_out, target)
        dice_term = self.second_criterion(net_out, target)
        return self.alpha * bce_term + (1 - self.alpha) * dice_term

# Unet

In [11]:
class VGG13Encoder(torch.nn.Module):
    """
    U-Net encoder built from the convolutional stages of torchvision's VGG13.

    Stage ``k`` reuses modules ``5k .. 5k+3`` of ``vgg13().features`` (two
    conv+ReLU pairs), skipping the max-pool at ``5k+4``. In ``forward`` the
    pooling is done manually between stages.
    NOTE: torchvision >= 0.13 deprecates ``pretrained=`` in favour of ``weights=``.
    """

    def __init__(self, num_blocks, pretrained=True):
        super().__init__()
        self.num_blocks = num_blocks
        features = vgg13(pretrained=pretrained).features
        self.blocks = torch.nn.ModuleList(
            torch.nn.Sequential(*(features[j] for j in range(5 * stage, 5 * stage + 4)))
            for stage in range(self.num_blocks)
        )

    def forward(self, x):
        """
        :return: list with the output of every stage, shallowest first. A 2x2
            max-pool is applied between consecutive stages, not after the last one.
        """
        activations = []
        last_stage = self.num_blocks - 1
        for stage, block in enumerate(self.blocks):
            x = block(x)
            activations.append(x)
            if stage < last_stage:
                x = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
        return activations

In [12]:
class DecoderBlock(torch.nn.Module):
    """
    U-Net decoder stage with a concatenation skip connection.

    ``down`` (2*C channels) is upsampled to the spatial size of ``left`` (C
    channels) and projected to C channels. It is concatenated with ``left``
    (2*C channels) and passed through two 3x3 conv + ReLU layers.
    """

    def __init__(self, out_channels):
        super().__init__()

        self.upconv = torch.nn.Conv2d(
            in_channels=out_channels * 2, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.conv1 = torch.nn.Conv2d(
            in_channels=out_channels * 2, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, down, left):
        """
        :param down: deeper feature map, (N, 2*C, h, w).
        :param left: skip connection from the encoder, (N, C, H, W).
        :return: (N, C, H, W).
        """
        # Upsample to the skip connection's exact size (nearest). For H = 2h this equals
        # scale_factor=2, and unlike it, it also works when H or W is odd.
        x = torch.nn.functional.interpolate(down, size=left.shape[-2:])
        x = self.upconv(x)
        x = self.relu(self.conv1(torch.cat([left, x], 1)))
        x = self.relu(self.conv2(x))
        return x

In [13]:
class Decoder(torch.nn.Module):
    """
    Stack of ``DecoderBlock`` stages.

    ``block1`` is the deepest (widest, num_filters * 2**(num_blocks-1) channels)
    and ``block{num_blocks}`` the shallowest (num_filters channels).
    """

    def __init__(self, num_filters, num_blocks):
        super().__init__()

        for depth in range(num_blocks):
            block_name = f'block{num_blocks - depth}'
            self.add_module(block_name, DecoderBlock(num_filters * 2 ** depth))

    def forward(self, acts):
        """
        :param acts: encoder activations, shallowest first.
        :return: output of the shallowest decoder block.
        """
        up = acts[-1]
        skips = acts[-2::-1]  # from the second-deepest activation up to the shallowest
        for block_number, left in enumerate(skips, start=1):
            up = getattr(self, f'block{block_number}')(up, left)
        return up

In [14]:
class UNet(torch.nn.Module):
    """
    U-Net segmentation model: VGG13 encoder, concatenation decoder, 1x1 classifier.

    :param num_classes: number of output channels of the final 1x1 convolution.
    :param num_filters: channel width of the shallowest decoder stage. It must
        match the encoder's first stage (64 for VGG13) for the skip
        connections to line up.
    :param num_blocks: number of encoder stages; the decoder has one fewer.
    """

    def __init__(self, num_classes=1, num_filters=64, num_blocks=4):
        super().__init__()
        self.encoder = VGG13Encoder(num_blocks=num_blocks)
        # Use the same width for the decoder and the final conv. Previously the decoder
        # hard-coded 64 while `final` used `num_filters`, so the two could disagree.
        self.decoder = Decoder(num_filters=num_filters, num_blocks=num_blocks - 1)
        self.final = torch.nn.Conv2d(
            in_channels=num_filters, out_channels=num_classes, kernel_size=1
        )

    def forward(self, x):
        """Return per-pixel logits of shape (N, num_classes, H, W)."""
        acts = self.encoder(x)
        x = self.decoder(acts)
        x = self.final(x)
        return x

In [15]:
def iou(outputs, labels, inline=False):
    """
    Mean intersection-over-union of thresholded predictions.

    :param outputs: logits of shape (N, H, W); sigmoid then 0.5 threshold.
    :param labels: binary targets of shape (N, H, W).
    :param inline: if True, print the batch value.
    :return: 0-d tensor with the IoU averaged over the batch. A 1e-6 smoothing
        term makes an empty-vs-empty sample score 1.
    """
    preds = torch.sigmoid(outputs)  # new tensor; `outputs` is left untouched
    preds[preds < 0.5] = 0
    preds[preds >= 0.5] = 1

    overlap = (preds * labels).float().sum((1, 2))
    combined = (preds + labels - preds * labels).float().sum((1, 2))
    per_sample = (overlap + 1e-6) / (combined + 1e-6)

    batch_mean = per_sample.mean()
    if inline:
        print('    acc on batch:', float(batch_mean))
    return batch_mean

In [16]:
def evaluate_net(net, testloader, criterion, val_criterion=iou, device='cpu'):
    """
    Evaluate a segmentation network on ``testloader``.

    The loader yields (images, labels); only channel 0 of ``labels`` is used as
    the target. Network outputs are moved to CPU and their channel dimension is
    squeezed before scoring.

    :return: (mean_loss, metric). ``mean_loss`` is the per-sample mean loss,
        assuming ``criterion`` averages over the batch (true for
        Combination_Loss). ``metric`` is ``val_criterion`` averaged over batches.
    """
    net = net.eval()

    loss_sum = 0.0
    metric_sum = 0.0
    num_samples = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in testloader:
            labels = labels[:, 0, :, :]
            outputs = net(images.to(device).float()).to('cpu').squeeze(1)

            batch_size = labels.size(0)
            # The criterion returns a batch mean; re-weight by batch size so dividing by the
            # number of samples gives a per-sample mean (the old code divided summed batch
            # means by the sample count, shrinking the loss by roughly the batch size).
            loss_sum += float(criterion(outputs, labels)) * batch_size
            metric_sum += float(val_criterion(outputs, labels))
            num_samples += batch_size
            num_batches += 1

    return loss_sum / num_samples, metric_sum / num_batches

In [17]:
def train(model):
    """
    Train a segmentation model using notebook-level globals.

    Globals read: ``epochs``, ``device``, ``optimizer``, ``criterion``,
    ``scheduler``, ``train_data_loader``, ``test_data_loader``,
    ``val_criterion`` (called as ``val_criterion(out, target, True)``).
    The model is saved to 'path_to_model.pth' at the end.

    :return: (accs, losses, times, acc_test, acc_train, loss_test, loss_train).
        ``accs``/``losses``/``times`` map epoch -> per-batch values; ``times``
        holds seconds elapsed since the start of the epoch. The last four lists
        hold one ``evaluate_net`` result per epoch.
    """
    times = {}
    losses = {}
    accs = {}
    acc_test = []
    acc_train = []
    loss_test = []
    loss_train = []
    for epoch in range(epochs):
        print('epoch:', epoch)
        times[epoch] = []
        losses[epoch] = []
        accs[epoch] = []
        model = model.to(device)
        model = model.train()
        epoch_start = time.time()
        for k, (data, target) in enumerate(train_data_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            net_out = model(data.float()).squeeze(1)
            target = target[:, 0, :, :]

            loss = criterion(net_out, target)
            loss.backward()
            optimizer.step()
            print('  loss on batch ', k, ':', loss.item())
            times[epoch] += [time.time() - epoch_start]

            losses[epoch] += [float(loss.item())]
            # The metric is for logging only; detach so no autograd graph is built for it.
            accs[epoch] += [float(val_criterion(net_out.detach(), target, True))]
            del loss
        scheduler.step()

        # Evaluate on the same device the model was moved to (was hard-coded to 'cuda:0',
        # which failed whenever `device` was anything else).
        info_test = evaluate_net(model, test_data_loader, criterion, iou, device)
        info_train = evaluate_net(model, train_data_loader, criterion, iou, device)
        print('epoch result:')
        print('...train:')
        print('......accy:', info_train[1])
        print('......loss:', info_train[0])
        print('...test:')
        print('......accy:', info_test[1])
        print('......loss:', info_test[0])
        acc_test += [info_test[1]]
        acc_train += [info_train[1]]
        loss_test += [info_test[0]]
        loss_train += [info_train[0]]
    torch.save(model, 'path_to_model.pth')
    return accs, losses, times, acc_test, acc_train, loss_test, loss_train

# LinkNet

In [18]:
class DecoderBlock2(torch.nn.Module):
    """
    LinkNet-style decoder stage with an additive skip connection.

    ``down`` (2*C channels) is upsampled to the spatial size of ``left`` (C
    channels) and projected to C channels. It is added to ``left`` and passed
    through two 3x3 conv + ReLU layers.
    """

    def __init__(self, out_channels):
        super().__init__()

        self.upconv = torch.nn.Conv2d(
            in_channels=out_channels * 2, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.conv1 = torch.nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1
        )
        self.relu = torch.nn.ReLU()

    def forward(self, down, left):
        """
        :param down: deeper feature map, (N, 2*C, h, w).
        :param left: skip connection from the encoder, (N, C, H, W).
        :return: (N, C, H, W).
        """
        # Upsample to the skip connection's exact size (nearest). For H = 2h this equals
        # scale_factor=2, and unlike it, it also works when H or W is odd.
        x = torch.nn.functional.interpolate(down, size=left.shape[-2:])
        x = self.upconv(x)
        x = self.relu(self.conv1(left + x))
        x = self.relu(self.conv2(x))
        return x

In [19]:
class Decoder2(torch.nn.Module):
    """
    Stack of ``DecoderBlock2`` (additive-skip) stages.

    ``block1`` is the deepest (widest, num_filters * 2**(num_blocks-1) channels)
    and ``block{num_blocks}`` the shallowest (num_filters channels).
    """

    def __init__(self, num_filters, num_blocks):
        super().__init__()

        for depth in range(num_blocks):
            block_name = f'block{num_blocks - depth}'
            self.add_module(block_name, DecoderBlock2(num_filters * 2 ** depth))

    def forward(self, acts):
        """
        :param acts: encoder activations, shallowest first.
        :return: output of the shallowest decoder block.
        """
        up = acts[-1]
        skips = acts[-2::-1]  # from the second-deepest activation up to the shallowest
        for block_number, left in enumerate(skips, start=1):
            up = getattr(self, f'block{block_number}')(up, left)
        return up

In [20]:
class LinkNet(torch.nn.Module):
    """
    LinkNet-style segmentation model: VGG13 encoder, additive-skip decoder, 1x1 classifier.

    :param num_classes: number of output channels of the final 1x1 convolution.
    :param num_filters: channel width of the shallowest decoder stage. It must
        match the encoder's first stage (64 for VGG13) for the skip
        connections to line up.
    :param num_blocks: number of encoder stages; the decoder has one fewer.
    """

    def __init__(self, num_classes=1, num_filters=64, num_blocks=4):
        super().__init__()
        self.encoder = VGG13Encoder(num_blocks=num_blocks)
        # Use the same width for the decoder and the final conv. Previously the decoder
        # hard-coded 64 while `final` used `num_filters`, so the two could disagree.
        self.decoder = Decoder2(num_filters=num_filters, num_blocks=num_blocks - 1)
        self.final = torch.nn.Conv2d(
            in_channels=num_filters, out_channels=num_classes, kernel_size=1
        )

    def forward(self, x):
        """Return per-pixel logits of shape (N, num_classes, H, W)."""
        acts = self.encoder(x)
        x = self.decoder(acts)
        x = self.final(x)
        return x