In [1]:
import torch
import timm
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import OrderedDict
import os
import random
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from glob import glob
import albumentations as A
from PIL import Image
import cv2

  from pandas.core import (


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_workers = 4
batch_size = 8
img_shape = [512, 512]
random_state = 69

In [3]:
def set_random_seed(seed: int = 2222, deterministic: bool = False):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = deterministic
    
set_random_seed(random_state)

In [4]:
def read_image(filepath):
    image = Image.open(filepath)
    image = np.array(image)
    image = cv2.resize(image, (512, 512))
    return image


def read_mask(filepath):
    image = Image.open(filepath).convert('L')
    image = np.array(image).astype(np.float64)
    image /= image.max()
    image = image.astype(np.uint8)
    image = cv2.resize(image, (512, 512))
    return image


In [5]:
labels, fns = [], glob('original_images')

with open("labels.txt" , "r") as fin:
    f = fin.readlines()
    for ff in f:
        labels.append(int(ff))
labels

[1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0]

In [6]:
import torch.nn as nn
ALPHA = 0.8
GAMMA = 2

class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()

    def forward(self, outputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        ce_loss = torch.nn.functional.cross_entropy(outputs, targets, reduction='none') # important to add reduction='none' to keep per-batch-item loss
        pt = torch.exp(-ce_loss)
        focal_loss = (alpha * (1-pt)**gamma * ce_loss).mean() # mean over the batch
        return focal_loss


In [7]:
transforms_train = A.Compose([

#     A.OneOf([
#         A.OpticalDistortion(distort_limit=1.0),
#         A.GridDistortion(num_steps=5, distort_limit=1.),
#         A.ElasticTransform(alpha=3),
#     ], p=AUG_PROB),

#     A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=AUG_PROB),
    A.Resize(img_shape[0], img_shape[1]),
#     A.CoarseDropout(max_holes=16, max_height=16, max_width=16, min_holes=1, min_height=2, min_width=2, p=AUG_PROB),    
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transforms_val = A.Compose([
    A.Resize(img_shape[0], img_shape[1]),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_for_masks = A.Compose([
    A.Resize(img_shape[0], img_shape[1])
])


In [8]:
def testing_model(model, image):
    #takes image and returns mask in 0-1 format
    model.eval()
    image = transforms_val(image=image[:, :, :])['image']
    image = image.transpose(2, 0, 1)
    image = torch.tensor(image)
    image = image.unsqueeze(dim=0)
    image = image.to(device)
    with torch.no_grad():
        proba = model(image[:, :, :, :]).detach().cpu().numpy()
        output = proba.round()[0,0,:,:]
    return output, proba[0, 0, ...]


from collections import deque
from tqdm.auto import tqdm

def bfs(image):
    n,m = len(image), len(image[0])
    graph = {i: [] for i in range(n*m)}
    
    # creating graph
    for row in range(n):
        for col in range(m):
            if image[row][col] == 1:
                u = row * m + col
                adj = []
                if row != 0: adj.append([row-1, col])
                if col != 0: adj.append([row, col-1])
                if row != n-1: adj.append([row+1, col])
                if col != m-1: adj.append([row, col+1])
                
                for point in adj:
                    y, x = point
                    if image[y][x] == 1:
                        v = y*m + x
                        graph[u].append(v)
                del adj, u
                        
    # doing bfs
    used, comps = [False for i in range(n*m)], []
    for cell in range(n*m):
        row = cell // m
        col = cell % m
        if image[row][col] == 0:
            continue
        if used[cell]:
            continue
            
        visited = []
        q = deque()
        start = cell
        q.append(start)
    
        while q:
            node = q.popleft()
            for v in graph[node]:
                if not used[v]:
                    q.append(v)
                    used[v] = True
                    visited.append(v)
                
        comps.append(visited)
        del visited, q
    return comps 

def inference_take_the_biggest_component(models, images_path, output_path='output_masks', save_npz=True):
    #inference of the model
    counts = []
    masks = {}
    if not Path(output_path).exists():
        os.mkdir(output_path)
    for i, image_path in tqdm(enumerate(images_path), total=len(images_path), position=0):
        image = read_image(image_path)
        super_output = np.zeros((img_shape[0], img_shape[1]))
        for model in models:
            model.eval()
            with torch.no_grad():
                output, proba = testing_model(model, image)
                super_output += (proba / len(models))     
        super_output = super_output.round()
        
        comps = bfs(super_output)
        counts.append(len(comps))
        
        largest_comp, max_size = [], 0
        for comp in comps:
            if max_size < len(comp):
                max_size = len(comp)
                largest_comp = comp
                
        one_comp_image = np.zeros((img_shape[0], img_shape[1]))
        for node in largest_comp:
            row, col = node // (img_shape[1]), node % (img_shape[0])
            one_comp_image[row, col] = 1
        
        normal_mask = []
        for row in range(one_comp_image.shape[0]):
            if one_comp_image[row, :].sum() > 0:
                normal_mask.append(list(one_comp_image[row, :]))
#         one_comp_image *= 255
        normal_mask = np.array(normal_mask)
        super_normal_mask = []
        for col in range(normal_mask.shape[1]):
            if normal_mask[:, col].sum() > 0:
                super_normal_mask.append(list(normal_mask[:, col]))
        super_normal_mask = np.array(super_normal_mask).T
        
        super_puper_mask = np.zeros((max(super_normal_mask.shape[0], super_normal_mask.shape[1]),
                                   max(super_normal_mask.shape[0], super_normal_mask.shape[1])))
        if super_normal_mask.shape[0] <= super_normal_mask.shape[1]:
            super_puper_mask[(super_normal_mask.shape[1] - super_normal_mask.shape[0]) // 2: 
                             (super_normal_mask.shape[1] - super_normal_mask.shape[0]) // 2 + super_normal_mask.shape[0], :] = super_normal_mask
            
        if super_normal_mask.shape[0] > super_normal_mask.shape[1]:
            super_puper_mask[:, (super_normal_mask.shape[0] - super_normal_mask.shape[1]) // 2: 
                             (super_normal_mask.shape[0] - super_normal_mask.shape[1]) // 2 + super_normal_mask.shape[1]] = super_normal_mask
        super_puper_puper = np.zeros((super_puper_mask.shape[0] + 60, super_puper_mask.shape[1] + 60))
        super_puper_puper[30:super_puper_mask.shape[0] + 30, 30:super_puper_mask.shape[1] + 30] = super_puper_mask
        masks[image_path.split('/')[-1][:-4]] = super_output * 255
        super_output *= 255
        
        cv2.imwrite(os.path.join(output_path, image_path.split('/')[-1][:-4] + '.png'), 
                    super_puper_puper * 255)

    if save_npz:
        np.savez('masks.npz', **masks)
    return counts

def inference(models, images_path, output_path='output_masks', save_npz=True):
    #inference of the model
    counts = []
    masks = {}
    if not Path(output_path).exists():
        os.mkdir(output_path)
    for i, image_path in tqdm(enumerate(images_path), total=len(images_path), position=0):
        image = read_image(image_path)
        super_output = np.zeros((img_shape[0], img_shape[1]))
        for model in models:
            model.eval()
            with torch.no_grad():
                output, proba = testing_model(model, image)
                super_output += (proba / len(models))     
        super_output = super_output.round()
        
        comps = bfs(super_output)
        counts.append(len(comps))
        masks[image_path.split('/')[-1][:-4]] = super_output * 255
        super_output *= 255
        cv2.imwrite(os.path.join(output_path, image_path.split('/')[-1][:-4] + '.png'), 
                    super_output)

    if save_npz:
        np.savez('masks.npz', **masks)
    return counts


In [9]:
if not Path('predicted_original_masks').exists():
    os.mkdir('predicted_original_masks')
    
model = torch.load('trained_models/fold 1/model_resnet18_lovasz_05_05_120.pt').to(device)
model.eval()

image_pathes = glob('original_images/*')
inference_take_the_biggest_component([model], image_pathes, output_path='predicted_original_masks')

  0%|          | 0/186 [00:00<?, ?it/s]

[32,
 20,
 6,
 3,
 36,
 47,
 98,
 19,
 5,
 74,
 6,
 10,
 17,
 68,
 12,
 2,
 10,
 27,
 92,
 31,
 59,
 86,
 41,
 2,
 24,
 5,
 9,
 32,
 8,
 81,
 175,
 61,
 99,
 60,
 30,
 2,
 42,
 50,
 61,
 247,
 21,
 48,
 88,
 1,
 88,
 40,
 3,
 10,
 56,
 27,
 35,
 18,
 68,
 19,
 14,
 1,
 13,
 10,
 16,
 100,
 1,
 5,
 1,
 3,
 29,
 60,
 14,
 267,
 8,
 94,
 29,
 9,
 25,
 35,
 88,
 10,
 6,
 28,
 37,
 4,
 6,
 10,
 15,
 19,
 13,
 21,
 23,
 3,
 58,
 14,
 1,
 3,
 11,
 7,
 18,
 4,
 1,
 159,
 22,
 3,
 14,
 32,
 3,
 42,
 9,
 75,
 1,
 2,
 7,
 14,
 44,
 4,
 3,
 19,
 22,
 8,
 4,
 161,
 5,
 20,
 6,
 88,
 10,
 8,
 5,
 9,
 11,
 9,
 37,
 34,
 32,
 9,
 35,
 7,
 34,
 8,
 6,
 63,
 29,
 33,
 15,
 97,
 59,
 14,
 94,
 7,
 24,
 24,
 5,
 121,
 116,
 13,
 47,
 84,
 40,
 126,
 5,
 137,
 3,
 5,
 100,
 42,
 21,
 28,
 5,
 1,
 3,
 114,
 2,
 4,
 25,
 52,
 32,
 20,
 18,
 83,
 10,
 21,
 5,
 141,
 2,
 81,
 14,
 27,
 28,
 46]

In [10]:
# class TypeDataset(Dataset):
#     def __init__(self, images_pathes, labels, transform=None):
#         super().__init__()
#         self.images_pathes = images_pathes
#         self.labels = labels
#         self.transform = transform
        
#     def __len__(self):
#         return len(self.labels)
    
#     def __getitem__(self, index):
#         path = self.images_pathes[index]
#         image = read_image(path) / 255.0
#         label = self.labels[index]
        
#         if self.transform is not None:
#             image = self.transform(image=image)['image']
#         image = torch.FloatTensor(image).view(1, image.shape[0], image.shape[1])
#         return image, label

class TypeDataset(Dataset):
    def __init__(self, images_pathes, labels, transform=None):
        super().__init__()
        self.images_pathes = images_pathes
        self.labels = labels
        self.transform = transform
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, index):
        path = self.images_pathes[index]
        image = read_image(path)
        label = self.labels[index]
        
        if self.transform is not None:
            image = self.transform(image=image)['image']
        image = torch.FloatTensor(image).permute(2, 0, 1)
        return image, label


In [11]:
from sklearn.model_selection import train_test_split

transforms_train = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Resize(img_shape[0], img_shape[1]),
#     A.OneOf([
#         A.OpticalDistortion(distort_limit=1.0),
#         A.GridDistortion(num_steps=5, distort_limit=1.),
#         A.ElasticTransform(alpha=3),
#     ], p=0.75),

    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.75),
    A.Normalize(mean=0.5, std=0.5)
])
transforms_val = A.Compose([
    A.Resize(img_shape[0], img_shape[1]),
    A.Normalize(mean=0.5, std=0.5)
])

ds = 'original_images'
images_pathes = glob(f'{ds}/*')
tim_path = []
tim_path.extend([os.path.join(f'{ds}', str(i) + '.png') for i in range(155) 
                 if os.path.join(f'{ds}', str(i) + '.png') in images_pathes])
tim_path.append(f'{ds}/images.png')
tim_path.extend([os.path.join(f'{ds}', 'images' + str(i) + '.png') for i in range(155) 
                 if os.path.join(f'{ds}', 'images' + str(i) + '.png') in images_pathes])
tim_path.append(f'{ds}/images-2.png')
tim_path.append(f'{ds}/images-3.png')

train, test, labels_train, labels_test = train_test_split(tim_path, labels, random_state=random_state,
                                                          test_size=0.2)
train_ds = TypeDataset(train, labels_train, transform=transforms_train)
test_ds = TypeDataset(test, labels_test, transform=transforms_val)
tim_path

['original_images/0.png',
 'original_images/1.png',
 'original_images/2.png',
 'original_images/3.png',
 'original_images/4.png',
 'original_images/5.png',
 'original_images/6.png',
 'original_images/7.png',
 'original_images/8.png',
 'original_images/10.png',
 'original_images/11.png',
 'original_images/12.png',
 'original_images/13.png',
 'original_images/14.png',
 'original_images/15.png',
 'original_images/16.png',
 'original_images/17.png',
 'original_images/18.png',
 'original_images/19.png',
 'original_images/20.png',
 'original_images/21.png',
 'original_images/22.png',
 'original_images/23.png',
 'original_images/24.png',
 'original_images/25.png',
 'original_images/26.png',
 'original_images/27.png',
 'original_images/28.png',
 'original_images/29.png',
 'original_images/30.png',
 'original_images/31.png',
 'original_images/32.png',
 'original_images/33.png',
 'original_images/34.png',
 'original_images/35.png',
 'original_images/36.png',
 'original_images/37.png',
 'original

In [12]:
sum(labels_train)

48

In [13]:
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=False)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, pin_memory=False)

In [14]:
from sklearn.metrics import accuracy_score, f1_score
import torch.nn as nn

model = timm.create_model('convnext_pico.d1_in1k', pretrained=True, 
                        num_classes=2, in_chans=3).to(device)
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0]).to(device))
# loss_fn = FocalLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
best_score = 0.0
best_model = None

for epoch in range(250):
    model.train()
    labels_pred, labels_true = [], []
    with tqdm(train_dl, total=len(train_dl), position=0, leave=True) as pbar:
        for images, labels in pbar:
            images = images.to(device)

            labels = labels.to(device)
            optimizer.zero_grad()
            
            output = model(images)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(
                OrderedDict(loss=loss.item())
            )
    model.eval()
    for images, labels in test_dl:
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            output = model(images)
            labels_pred.extend(list(output.softmax(dim=1).argmax(dim=1).detach().cpu().numpy()))
            labels_true.extend(list(labels.detach().cpu().numpy()))

    acc = accuracy_score(labels_true, labels_pred)
    f1 = f1_score(labels_pred, labels_true)
    print(f'epoch: {epoch + 1}: acc: {acc} f1: {f1}')
    
    if f1 * acc > best_score:
        best_model = model
        best_score = f1 * acc

  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 1: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 2: acc: 0.8421052631578947 f1: 0.5


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 3: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 4: acc: 0.8947368421052632 f1: 0.7777777777777778


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 5: acc: 0.8947368421052632 f1: 0.7777777777777778


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 6: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 7: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 8: acc: 0.868421052631579 f1: 0.6666666666666666


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 9: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 10: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 11: acc: 0.8421052631578947 f1: 0.6666666666666666


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 12: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 13: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 14: acc: 0.8947368421052632 f1: 0.75


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 15: acc: 0.868421052631579 f1: 0.6666666666666666


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 16: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 17: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 18: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 19: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 20: acc: 0.8421052631578947 f1: 0.6666666666666666


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 21: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 22: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 23: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 24: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 25: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 26: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 27: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 28: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 29: acc: 0.8421052631578947 f1: 0.6666666666666666


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 30: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 31: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 32: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 33: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 34: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 35: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 36: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 37: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 38: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 39: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 40: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 41: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 42: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 43: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 44: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 45: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 46: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 47: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 48: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 49: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 50: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 51: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 52: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 53: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 54: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 55: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 56: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 57: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 58: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 59: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 60: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 61: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 62: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 63: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 64: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 65: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 66: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 67: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 68: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 69: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 70: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

epoch: 71: acc: 0.868421052631579 f1: 0.7058823529411765


  0%|          | 0/19 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [15]:
torch.save(best_model, 'type.pt')