In [None]:
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz

In [None]:
!tar -xzvf "./imagenette2-160.tgz" -C "."
pass

In [None]:
!pip install torch-summary

In [None]:
import pickle
import numpy as np
from skimage import io

import random
from time import time
from copy import deepcopy
#from torchsummary import summary
from torch.autograd import Variable

from torchvision.datasets import ImageFolder
import cv2
from torchvision.utils import make_grid 

from tqdm import tqdm, tqdm_notebook
from PIL import Image
from pathlib import Path
import os
import gc
from torchsummary import summary
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR
from torchvision import transforms
from multiprocessing.pool import ThreadPool
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision.utils import save_image
import torchvision.models as tm

from matplotlib import colors, pyplot as plt
%matplotlib inline

# sklearn emits DeprecationWarnings that clutter the cell output (e.g. around images
# plotted in Colab), so DeprecationWarnings are ignored.
import warnings
warnings.filterwarnings(action='ignore', category=DeprecationWarning)

In [None]:
# Dataset modes accepted by ImageNetteDataset.
DATA_MODES = ['train', 'val', 'test']
# All images are rescaled to 64x64 px.
RESCALE_SIZE = 64
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed the RNGs the notebook draws from (python `random`, numpy sampling, torch
# augmentations / weight init / DataLoader shuffling) so Restart & Run All is
# reproducible. cuDNN kernels may still introduce small non-determinism on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class ImageNetteDataset(Dataset):
    """ImageNette image dataset backed by a list of file paths.

    Images are read from disk lazily in __getitem__, randomly augmented in
    'train' mode only, resized to RESCALE_SIZE x RESCALE_SIZE and returned as
    float32 CHW tensors with values in [0, 1].

    Args:
        files: iterable of pathlib.Path objects; each image's class label is the
            name of its parent directory.
        mode: one of DATA_MODES. 'test' returns images only; 'train' and 'val'
            return (image, label) pairs.
        part: 1-based task index in the class-incremental setup. The label
            encoder is fit only on this dataset's own classes, so encoded labels
            are shifted by CLASSES_PER_PART * (part - 1) onto global class ids.

    Raises:
        NameError: if `mode` is not in DATA_MODES.

    Side effect: in 'train'/'val' mode the fitted LabelEncoder is pickled to
    'label_encoder.pkl' in the current directory, overwriting any earlier dump.
    """

    # Number of classes per task; determines the label shift applied for `part`.
    CLASSES_PER_PART = 2

    def __init__(self, files, mode, part=1):
        super().__init__()
        # Sorted so that indices (and therefore labels) are deterministic.
        self.files = sorted(files)
        self.mode = mode
        self.part = part

        if self.mode not in DATA_MODES:
            print(f"{self.mode} is not correct; correct modes: {DATA_MODES}")
            raise NameError(f"{self.mode} is not correct; correct modes: {DATA_MODES}")

        self.len_ = len(self.files)

        self.label_encoder = LabelEncoder()

        if self.mode != 'test':
            self.labels = [path.parent.name for path in self.files]
            self.label_encoder.fit(self.labels)
            # Encode every label once instead of calling the encoder per sample.
            self.label_ids = self.label_encoder.transform(self.labels)

            with open('label_encoder.pkl', 'wb') as le_dump_file:
                pickle.dump(self.label_encoder, le_dump_file)

        # Random augmentation; applied to training samples only (see __getitem__).
        self.aug_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(8),
        ])
        # HWC float32 array -> CHW tensor. ToTensor only rescales uint8 input, so the
        # division by 255 done in __getitem__ is applied exactly once.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return self.len_

    def load_sample(self, file):
        """Read the image at `file` as a 3-channel RGB PIL image."""
        # convert('RGB') turns grayscale ('L') and CMYK JPEGs into 3 channels, so all
        # samples share one shape; the context manager closes the file handle.
        with Image.open(file) as image:
            return image.convert('RGB')

    def __getitem__(self, index):
        x = self.load_sample(self.files[index])
        if self.mode == 'train':
            # Validation/test images must not be randomly flipped/rotated, otherwise
            # their metrics and predictions change from run to run.
            x = self.aug_transform(x)
        x = self._prepare_sample(x)
        x = np.array(x / 255, dtype='float32')
        x = self.transform(x)
        if self.mode == 'test':
            return x
        y = int(self.label_ids[index])
        # Shift the dataset-local label onto the global class id of this task.
        return x, y + self.CLASSES_PER_PART * (self.part - 1)

    def _prepare_sample(self, image):
        """Resize a PIL image to RESCALE_SIZE x RESCALE_SIZE and return it as an ndarray."""
        image = image.resize((RESCALE_SIZE, RESCALE_SIZE))
        return np.array(image)

def make_weights_for_balanced_classes(images, nclasses):
    """Compute per-sample weights for class-balanced sampling.

    Every sample of class c gets weight N / count[c], where N is the total number
    of samples, so each class carries the same total weight.

    Args:
        images: iterable (iterated twice) of pairs whose element [1] is an integer
            class index in range(nclasses).
        nclasses: number of classes.

    Returns:
        list of float weights, one per element of `images`, in the same order.

    Classes that never occur get weight 0.0 instead of raising ZeroDivisionError;
    no sample ever uses that weight.
    """
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    total = float(sum(count))
    weight_per_class = [total / c if c else 0.0 for c in count]
    return [weight_per_class[item[1]] for item in images]

def imshow(inp, title=None, plt_ax=plt, default=False):
    """Show a CHW image tensor with values in [0, 1] on a matplotlib Axes.

    Args:
        inp: image tensor of shape (C, H, W). It may live on any device and may
            require grad; it is detached and copied to the CPU for plotting.
        title: optional axes title.
        plt_ax: a matplotlib Axes. When left as the `pyplot` module, the current
            axes (`plt.gca()`) is used, because the module has no `set_title`.
        default: unused; kept for backward compatibility.
    """
    ax = plt_ax if hasattr(plt_ax, 'set_title') else plt_ax.gca()
    image = inp.detach().cpu().numpy().transpose((1, 2, 0))  # CHW -> HWC for imshow
    # Un-normalisation, needed only if transforms.Normalize is enabled in the dataset:
    #mean = np.array([0.485, 0.456, 0.406])
    #std = np.array([0.229, 0.224, 0.225])
    #image = std * image + mean
    image = np.clip(image, 0, 1)
    ax.imshow(image)
    if title is not None:
        ax.set_title(title)
    ax.grid(False)

In [None]:
# The archive is extracted to ./imagenette2-160; work from inside it so the
# relative ./train and ./val paths below resolve. Checking for the folder rather
# than comparing against the Kaggle path '/kaggle/working' also works on Colab or
# locally, and re-running the cell is harmless (no nested folder once inside).
if os.path.isdir('./imagenette2-160'):
    os.chdir('./imagenette2-160')

# Class index -> ImageNette class folder name (WordNet id), in sorted order.
dir_names = {
    0: 'n01440764',
    1: 'n02102040',
    2: 'n02979186',
    3: 'n03000684',
    4: 'n03028079',
    5: 'n03394916',
    6: 'n03417042',
    7: 'n03425413',
    8: 'n03445777',
    9: 'n03888257',
}

# Optional: create per-class output folders for generated images.
#for i in range(1,5):
#    if not os.path.isdir('generated' + '_' + str(i)):
#        os.mkdir('generated' + '_' + str(i))
#    for j in range(10):
#        if not os.path.isdir('generated'+ '_' + str(i) + '/' + dir_names[j]):
#            os.mkdir('generated'+ '_' + str(i) + '/' + dir_names[j])

#if not os.path.isdir('generated'):
#    os.mkdir('generated')
#for j in range(10):
#    if not os.path.isdir('generated/' + dir_names[j]):
#        os.mkdir('generated/'+dir_names[j])
#


In [None]:
# Both official splits are pooled and later re-split per task, so TEST_DIR (the
# official 'val' folder) also contributes training images.
TRAIN_DIR, TEST_DIR = Path('./train'), Path('./val')

train_files = [*TRAIN_DIR.rglob('*.JPEG')]
val_files = [*TEST_DIR.rglob('*.JPEG')]
# Sorted: all ./train paths first, then ./val, each grouped by class folder.
train_val_files = sorted([*train_files, *val_files])

In [None]:
from sklearn.model_selection import train_test_split

# Each image's class label is the name of its parent folder (a WordNet id).
train_val_labels = list(map(lambda image_path: image_path.parent.name, train_val_files))

In [None]:
# Class-incremental setup: the 10 classes are split into 5 sequential tasks
# ("parts") of 2 classes each, in sorted class-name order. Each part selects its
# classes' images from the pooled, sorted train+val list (preserving that order)
# instead of using hard-coded slice indices. Those indices silently mix classes
# when per-class image counts differ from the dataset version they were computed for.
N_PARTS = 5
CLASSES_PER_PART = 2
SPLIT_SEED = 42  # fixed so the per-task train/val split is reproducible

class_names = sorted(set(train_val_labels))
assert len(class_names) == N_PARTS * CLASSES_PER_PART, class_names

part_files, part_labels = [], []
for part_idx in range(N_PARTS):
    part_classes = set(class_names[part_idx * CLASSES_PER_PART:(part_idx + 1) * CLASSES_PER_PART])
    members = [i for i, label in enumerate(train_val_labels) if label in part_classes]
    part_files.append([train_val_files[i] for i in members])
    part_labels.append([train_val_labels[i] for i in members])

(train_val_labels_1, train_val_labels_2, train_val_labels_3,
 train_val_labels_4, train_val_labels_5) = part_labels
(train_val_files_1, train_val_files_2, train_val_files_3,
 train_val_files_4, train_val_files_5) = part_files

# Stratified 80/20 train/val split inside every task.
task_splits = [
    train_test_split(files, test_size=0.2, stratify=labels, random_state=SPLIT_SEED)
    for files, labels in zip(part_files, part_labels)
]
(
    (train_files_1, val_files_1),
    (train_files_2, val_files_2),
    (train_files_3, val_files_3),
    (train_files_4, val_files_4),
    (train_files_5, val_files_5),
) = task_splits

train_dataset_1 = ImageNetteDataset(train_files_1, mode='train', part=1)
# The training datasets for tasks 2-5 are built in the next cell.

# Validation set k is cumulative: it holds the held-out images of tasks 1..k, so it
# measures how much of all earlier tasks the model still knows.
val_dataset_1 = ImageNetteDataset(val_files_1, mode='val')
val_dataset_2 = ImageNetteDataset(val_files_1 + val_files_2, mode='val')
val_dataset_3 = ImageNetteDataset(val_files_1 + val_files_2 + val_files_3, mode='val')
val_dataset_4 = ImageNetteDataset(val_files_1 + val_files_2 + val_files_3 + val_files_4, mode='val')
val_dataset_5 = ImageNetteDataset(val_files_1 + val_files_2 + val_files_3 + val_files_4 + val_files_5, mode='val')




In [None]:
# Training datasets for tasks 2-5; `part` shifts their labels onto global class ids.
train_dataset_2, train_dataset_3, train_dataset_4, train_dataset_5 = [
    ImageNetteDataset(part_files, mode='train', part=part_no)
    for part_no, part_files in [
        (2, train_files_2),
        (3, train_files_3),
        (4, train_files_4),
        (5, train_files_5),
    ]
]

In [None]:
# Total number of held-out validation images across all five tasks.
sum(map(len, (val_files_1, val_files_2, val_files_3, val_files_4, val_files_5)))

In [None]:
# 25 random samples from cumulative validation set 1, titled with their class id.
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(12, 12), sharey=True, sharex=True)
for fig_x in ax.flatten():
    # Sample over the real dataset length instead of a hard-coded bound, which
    # raises IndexError as soon as the split sizes change.
    sample_idx = np.random.randint(len(val_dataset_1))
    im_val, label = val_dataset_1[sample_idx]
    imshow(im_val.data.cpu(), title=label, plt_ax=fig_x)

In [None]:
# 25 random samples from cumulative validation set 2, titled with their class id.
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(12, 12), sharey=True, sharex=True)
for fig_x in ax.flatten():
    # Sample over the real dataset length instead of a hard-coded bound, which
    # raises IndexError as soon as the split sizes change.
    sample_idx = np.random.randint(len(val_dataset_2))
    im_val, label = val_dataset_2[sample_idx]
    imshow(im_val.data.cpu(), title=label, plt_ax=fig_x)

In [None]:
# 25 random samples from cumulative validation set 3, titled with their class id.
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(12, 12), sharey=True, sharex=True)
for fig_x in ax.flatten():
    # Sample over the real dataset length instead of a hard-coded bound, which
    # raises IndexError as soon as the split sizes change.
    sample_idx = np.random.randint(len(val_dataset_3))
    im_val, label = val_dataset_3[sample_idx]
    imshow(im_val.data.cpu(), title=label, plt_ax=fig_x)

In [None]:
# 25 random samples from cumulative validation set 4, titled with their class id.
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(12, 12), sharey=True, sharex=True)
for fig_x in ax.flatten():
    # Sample over the real dataset length instead of a hard-coded bound, which
    # raises IndexError as soon as the split sizes change.
    sample_idx = np.random.randint(len(val_dataset_4))
    im_val, label = val_dataset_4[sample_idx]
    imshow(im_val.data.cpu(), title=label, plt_ax=fig_x)

In [None]:
# 25 random samples from cumulative validation set 5, titled with their class id.
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(12, 12), sharey=True, sharex=True)
for fig_x in ax.flatten():
    # Sample over the real dataset length instead of a hard-coded bound, which
    # raises IndexError as soon as the split sizes change.
    sample_idx = np.random.randint(len(val_dataset_5))
    im_val, label = val_dataset_5[sample_idx]
    imshow(im_val.data.cpu(), title=label, plt_ax=fig_x)

In [None]:
count = 0


def save_generated(images, labels, count, number):
    """Save a batch of images into per-class folders under ./generated.

    Args:
        images: sequence of image tensors accepted by torchvision's save_image.
        labels: class index per image; mapped to a folder name via dir_names.
        count: running file counter; files are named <count>.JPEG.
        number: unused (the per-task 'generated_<number>' layout is disabled).

    Returns:
        The counter after the last saved image, to pass into the next call.
    """
    for img, label in zip(images, labels):
        class_dir = os.path.join('generated', dir_names[int(label)])
        # save_image does not create missing folders, so make sure this one exists.
        os.makedirs(class_dir, exist_ok=True)
        #save_image(img, 'generated_' + str(number) + '/' + dir_names[int(label)] + '/' + str(count) + '.JPEG')
        save_image(img, os.path.join(class_dir, str(count) + '.JPEG'))
        count += 1
    return count
        
def get_gen_files(number):
    """Return every *.JPEG file found recursively under ./generated.

    `number` is currently unused: all generated images share a single folder
    (the per-task layout was ./generated_<number>).
    """
    #gen_dir = Path('./generated_' + str(number))
    return [jpeg_path for jpeg_path in Path('./generated').rglob('*.JPEG')]

In [None]:
!pip install efficientnet_pytorch

In [None]:
from efficientnet_pytorch import EfficientNet

# ImageNet-pretrained EfficientNet-B0; its classification head is replaced below.
efficient_net = EfficientNet.from_pretrained('efficientnet-b0')

# Read the head's input width from the model (1280 for b0) instead of hard-coding
# it, so switching to another EfficientNet variant keeps working.
num_features = efficient_net._fc.in_features
out_features = 10  # all ImageNette classes (the tasks reveal them 2 at a time)

efficient_net._fc = nn.Linear(num_features, out_features)

In [None]:
# Layer-by-layer summary for a 3x64x64 input; the trailing ';' keeps IPython from
# also displaying the returned object.
summary(efficient_net, (3, 64, 64));

In [None]:
# Inspect the available NVIDIA GPU(s) and their memory usage.
!nvidia-smi

In [None]:
# Number of distinct classes in the pooled train+val images.
n_classes = np.unique(train_val_labels).size
efficient_net = efficient_net.to(device)
print(f"we will classify :{n_classes}")

In [None]:
def count_parameters(model):
    """Return the number of scalar values in the parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
    
# Display the number of trainable parameters of the EfficientNet with its new head.
count_parameters(efficient_net)

In [None]:
#torch.save(efficient_net.state_dict(), 'effnet_total.pth') # score 0.9055

In [None]:
def weights_init(m):
    """Initialise one module in place; meant to be used as `model.apply(weights_init)`.

    Conv* layers: weight ~ N(0, 0.02). BatchNorm* layers: weight ~ N(1, 0.02),
    bias = 0 (the scheme used in the PyTorch DCGAN tutorial). Other modules are
    left untouched. Parameters that are None (e.g. BatchNorm with affine=False)
    are skipped instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
def variable(t: torch.Tensor, use_cuda=True, **kwargs):
    """Move `t` to the GPU (when CUDA is available and `use_cuda` is true) and wrap
    it with the legacy `torch.autograd.Variable` constructor.

    Any extra keyword arguments are forwarded to `Variable`.
    NOTE(review): `Variable` is deprecated since PyTorch 0.4; confirm whether callers
    rely on its detaching behaviour before replacing this with `t.to(device)`.
    """
    if torch.cuda.is_available() and use_cuda:
        t = t.cuda()
    return Variable(t, **kwargs)

class EWC(object):
    """Elastic Weight Consolidation penalty (Kirkpatrick et al., 2017).

    On construction, snapshots the model's trainable parameters and estimates a
    diagonal importance weight F_n for each of them from one batch of earlier-task
    data: the squared gradient of the batch cross-entropy, divided by the batch
    size. `penalty(model)` returns sum_n sum(F_n * (p_n - p_n_old) ** 2) over the
    tracked parameters.

    Args:
        old_x: input batch from earlier task(s), in the form `model` accepts.
        old_y: tensor of integer class labels for `old_x`.
        model: the network being trained. Its parameter values are not modified;
            its gradients are used and then cleared during construction.
    """

    def __init__(self, old_x, old_y, model):
        self.model = model
        self.old_x = old_x
        self.old_y = old_y
        # Only trainable parameters are tracked; frozen ones are ignored everywhere.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._precision_matrices = self._diag_fisher()
        # Detached copies of the values the penalty pulls the parameters back towards.
        self._old_params = {n: p.detach().clone() for n, p in self.params.items()}

    def _diag_fisher(self):
        """Return {name: squared-gradient importance tensor} for every tracked parameter."""
        precision_matrices = {n: torch.zeros_like(p) for n, p in self.params.items()}
        if not precision_matrices:
            return precision_matrices

        # Run the batch on the device the parameters live on.
        target = next(iter(self.params.values())).device
        was_training = self.model.training
        self.model.eval()
        self.model.zero_grad()
        self.old_x = self.old_x.to(target)
        output = self.model(self.old_x)
        # log_softmax + nll_loss is cross-entropy, assuming the model outputs raw logits.
        loss = F.nll_loss(F.log_softmax(output, dim=1), self.old_y.to(target))
        loss.backward()

        batch_size = len(self.old_x)
        for n, p in self.params.items():
            if p.grad is not None:  # parameters unused in the forward pass get no grad
                precision_matrices[n] += p.grad.detach() ** 2 / batch_size

        # Do not leave the estimation gradients or eval mode behind.
        self.model.zero_grad()
        self.model.train(was_training)
        return precision_matrices

    def penalty(self, model: nn.Module):
        """Return the EWC penalty for `model`'s current parameters (0 if none are tracked)."""
        loss = 0
        for n, p in model.named_parameters():
            if n not in self._precision_matrices:  # frozen / untracked parameter
                continue
            loss += (self._precision_matrices[n] * (p - self._old_params[n]) ** 2).sum()
        return loss
    

def fit_epoch(model, train_loader, criterion, optimizer):
    """Run one training epoch.

    Args:
        model: network returning raw class logits.
        train_loader: yields (inputs, integer labels) batches.
        criterion: loss taking (logits, labels), e.g. nn.CrossEntropyLoss.
        optimizer: optimizer over the model's parameters.

    Returns:
        (mean loss per sample, accuracy) over the epoch.

    Raises:
        ValueError: if train_loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0
    processed_data = 0
    for inputs, labels in tqdm(train_loader, desc = "iter:", position=0, leave=True):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(inputs)
        # CrossEntropyLoss applies log-softmax itself and must get raw logits;
        # passing softmax probabilities would apply softmax twice and weaken gradients.
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        preds = torch.argmax(logits, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()
        processed_data += inputs.size(0)

    if processed_data == 0:
        raise ValueError("train_loader yielded no samples")
    train_loss = running_loss / processed_data
    train_acc = running_corrects / processed_data
    return train_loss, train_acc

def fit_ewc_epoch(model, train_loader, criterion, optimizer, ewc, importance):
    """Run one training epoch with an EWC regulariser added to the task loss.

    Per batch: loss = criterion(logits, labels) + importance * ewc.penalty(model).

    Returns:
        (mean loss per sample including the penalty term, accuracy).

    Raises:
        ValueError: if train_loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0
    processed_data = 0
    for inputs, labels in tqdm(train_loader, desc = "iter:", position=0, leave=True):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)  # raw logits, as CrossEntropyLoss expects
        loss = criterion(outputs, labels) + importance * ewc.penalty(model)
        loss.backward()
        optimizer.step()
        preds = torch.argmax(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()
        processed_data += inputs.size(0)

    if processed_data == 0:
        raise ValueError("train_loader yielded no samples")
    train_loss = running_loss / processed_data
    train_acc = running_corrects / processed_data
    return train_loss, train_acc


def eval_epoch(model, val_loader, criterion):
    """Evaluate `model` on `val_loader` without tracking gradients.

    Args:
        model: network returning raw class logits.
        val_loader: yields (inputs, integer labels) batches.
        criterion: loss taking (logits, labels), e.g. nn.CrossEntropyLoss.

    Returns:
        (mean loss per sample as a float, accuracy as a 0-d float64 tensor).

    Raises:
        ValueError: if val_loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    processed_size = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Raw logits: CrossEntropyLoss applies log-softmax itself.
            loss = criterion(logits, labels)
            preds = torch.argmax(logits, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)
            processed_size += inputs.size(0)

    if processed_size == 0:
        raise ValueError("val_loader yielded no samples")
    val_loss = running_loss / processed_size
    val_acc = running_corrects.double() / processed_size
    return val_loss, val_acc
    
    
    
def _val_loaders_for(number, batch_size):
    """Build DataLoaders for the cumulative validation sets 1..`number`.

    val_dataset_k holds the held-out images of tasks 1..k, so loader k measures how
    much of everything learned up to task k is retained.

    Raises:
        ValueError: if `number` is not between 1 and 5.
    """
    all_val_datasets = [val_dataset_1, val_dataset_2, val_dataset_3, val_dataset_4, val_dataset_5]
    if not 1 <= number <= len(all_val_datasets):
        raise ValueError(f"number must be between 1 and {len(all_val_datasets)}, got {number}")
    return [DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
            for dataset in all_val_datasets[:number]]


def _format_epoch_log(epoch, train_loss, train_acc, val_losses, val_accs):
    """Return the one-line per-epoch summary written with tqdm.write."""
    losses = " ".join(f"val_loss_{k} {loss:0.4f}" for k, loss in enumerate(val_losses, 1))
    accs = " ".join(f"val_acc_{k} {acc:0.4f}" for k, acc in enumerate(val_accs, 1))
    return (f"\nEpoch {epoch + 1:03d} train_loss: {train_loss:0.4f} {losses} "
            f"train_acc {train_acc:0.4f} {accs}")


def train(train_dataset, model, epochs, number, batch_size):
    """Train `model` on one task, keeping the weights with the best mean val accuracy.

    Args:
        train_dataset: dataset of the current task.
        model: network returning class logits; trained in place.
        epochs: number of training epochs.
        number: 1-based index (1-5) of the current task. Accuracy is measured on
            the cumulative validation sets 1..number and averaged into a total.
        batch_size: batch size for the training and validation loaders.

    Returns:
        history: one tuple per epoch,
            (train_loss, train_acc, val_loss_1, val_acc_1, ..., val_loss_n, val_acc_n).

    Raises:
        ValueError: if `number` is not between 1 and 5.

    Side effect: on return, `model` holds the best-scoring weights seen in training.
    """
    gc.collect()
    torch.cuda.empty_cache()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loaders = _val_loaders_for(number, batch_size)

    best_acc = 0.0
    # state_dict() returns references to the live parameter tensors; without a deep
    # copy, later epochs would overwrite the "best" snapshot.
    best_model_weights = deepcopy(model.state_dict())
    history = []

    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        opt = torch.optim.Adam(model.parameters(), lr=4e-4)
        criterion = nn.CrossEntropyLoss()
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.7)
        for epoch in range(epochs):
            train_loss, train_acc = fit_epoch(model, train_loader, criterion, opt)
            exp_lr_scheduler.step()  # multiply the LR by 0.7 every 3 epochs
            pbar_outer.update(1)

            val_losses, val_accs = [], []
            for val_loader in val_loaders:
                val_loss_k, val_acc_k = eval_epoch(model, val_loader, criterion)
                val_losses.append(val_loss_k)
                val_accs.append(val_acc_k)

            # Training metrics followed by interleaved (loss, acc) per validation set.
            history.append((train_loss, train_acc)
                           + tuple(metric for pair in zip(val_losses, val_accs) for metric in pair))
            for k, acc in enumerate(val_accs, 1):
                print(f'accuracy on {k}:', acc)
            val_acc = sum(val_accs) / len(val_accs)
            print('total_accuracy:', val_acc)
            tqdm.write(_format_epoch_log(epoch, train_loss, train_acc, val_losses, val_accs))

            # Keep the model with the best mean accuracy over all validation sets.
            if val_acc > best_acc:
                best_acc = val_acc
                best_model_weights = deepcopy(model.state_dict())
    print('Best total accuracy:', best_acc)
    model.load_state_dict(best_model_weights)  # restore the best weights
    return history

def make_prev(dataset):
    """Materialise a whole labelled dataset into two aligned tensors.

    Returns:
        (X, y): X stacks every sample along a new first axis and y holds the
        matching labels, so X[i] and y[i] always come from the same sample.

    Raises:
        ValueError: if `dataset` is empty.
    """
    print("Making previous tensors")
    if len(dataset) == 0:
        raise ValueError("make_prev needs a non-empty dataset")
    samples, labels = [], []
    for i in tqdm(range(len(dataset))):
        x, y = dataset[i]
        samples.append(x)
        labels.append(y)
    # A single stack at the end avoids the O(n^2) copying of torch.cat per sample.
    return torch.stack(samples), torch.as_tensor(labels)
    


def new_ewc_train(train_dataset, prev_dataset, model, epochs, number, batch_size, importance):
    prev_loader = DataLoader(prev_dataset, batch_size=64, shuffle=True, num_workers=2)

    best_acc = 0.0
    gc.collect()
    torch.cuda.empty_cache()
    best_model_weights = model.state_dict()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    if number == 1:
        val_loader_1 = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=2)
    elif number == 2:
        val_loader_1 = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_2 = DataLoader(val_dataset_2, batch_size=batch_size, shuffle=False, num_workers=2)
    elif number == 3:
        val_loader_1 = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_2 = DataLoader(val_dataset_2, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_3 = DataLoader(val_dataset_3, batch_size=batch_size, shuffle=False, num_workers=2)
    elif number == 4:
        val_loader_1 = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_2 = DataLoader(val_dataset_2, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_3 = DataLoader(val_dataset_3, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_4 = DataLoader(val_dataset_4, batch_size=batch_size, shuffle=False, num_workers=2)
    elif number == 5:
        val_loader_1 = DataLoader(val_dataset_1, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_2 = DataLoader(val_dataset_2, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_3 = DataLoader(val_dataset_3, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_4 = DataLoader(val_dataset_4, batch_size=batch_size, shuffle=False, num_workers=2)
        val_loader_5 = DataLoader(val_dataset_5, batch_size=batch_size, shuffle=False, num_workers=2)


    history = []

    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        opt = torch.optim.Adam(model.parameters(), lr = 4e-4)
        criterion = nn.CrossEntropyLoss()
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size = 3, gamma = 0.7)
        for epoch in range(epochs):
            prev_X, prev_y = next(iter(prev_loader))
            ewc = EWC(prev_X, prev_y, model)
            train_loss, train_acc = fit_ewc_epoch(model, train_loader, criterion, opt, ewc, importance)
            exp_lr_scheduler.step() #сюда добавили "шедулер"
            pbar_outer.update(1)
            if number == 1:
                val_loss_1, val_acc_1 = eval_epoch(model, val_loader_1, criterion)
                history.append((train_loss, train_acc, val_loss_1, val_acc_1))
                val_acc = val_acc_1
                print('accuracy on 1:', val_acc_1)
                print('total_accuracy:', val_acc)
                tqdm.write("\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
                            val_loss_1 {v_loss:0.4f} train_acc {t_acc:0.4f} \
                            val_acc_1 {v_acc:0.4f}".format(ep=epoch+1, t_loss=train_loss,
                                           v_loss=val_loss_1, t_acc=train_acc, v_acc=val_acc_1))

            elif number == 2:
                val_loss_1, val_acc_1 = eval_epoch(model, val_loader_1, criterion)
                val_loss_2, val_acc_2 = eval_epoch(model, val_loader_2, criterion)
                history.append((train_loss, train_acc, val_loss_1, val_acc_1, val_loss_2, val_acc_2))
                val_acc = (val_acc_1 + val_acc_2)/2
                prev_val = val_acc_1
                print('accuracy on 1:', val_acc_1)
                print('accuracy on 2:', val_acc_2)
                print('total_accuracy:', val_acc)
                
                
                
                if val_acc > best_acc and prev_val - val_acc_2 < 0.02:
                    best_acc = val_acc
                    best_model_weights = model.state_dict()      
                
                tqdm.write("\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
                            val_loss_1 {v_loss:0.4f} val_loss2 {v_loss2:0.4} train_acc {t_acc:0.4f} \
                            val_acc_1 {v_acc:0.4f} val_acc_2 {v_acc2:0.4f}".format(ep=epoch+1, t_loss=train_loss,
                                           v_loss=val_loss_1, v_loss2=val_loss_2, t_acc=train_acc, v_acc=val_acc_1, v_acc2=val_acc_2))

            elif number == 3:
                val_loss_1, val_acc_1 = eval_epoch(model, val_loader_1, criterion)
                val_loss_2, val_acc_2 = eval_epoch(model, val_loader_2, criterion)
                val_loss_3, val_acc_3 = eval_epoch(model, val_loader_3, criterion)
                history.append((train_loss, train_acc, val_loss_1, val_acc_1, val_loss_2, val_acc_2, val_loss_3, val_acc_3))
                print('accuracy on 1:', val_acc_1)
                print('accuracy on 2:', val_acc_2)
                print('accuracy on 3:', val_acc_3)
                val_acc = (val_acc_1 + val_acc_2 + val_acc_3)/3
                prev_val = (val_acc_1 + val_acc_2)/2
                print('prev_val_acc:', prev_val)
                print('total_accuracy:', val_acc)
                
                if val_acc > best_acc and prev_val - val_acc_3 < 0.02:
                    best_acc = val_acc
                    best_model_weights = model.state_dict()  
                
                tqdm.write("\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
                            val_loss_1 {v_loss:0.4f} val_loss_2 {v_loss2:0.4f} val_loss_3 {v_loss3:0.4f} \
                            train_acc {t_acc:0.4f} val_acc_1 {v_acc:0.4f} val_acc_2 {v_acc2:0.4f} val_acc_3 {v_acc3:0.4f}\
                            ".format(ep=epoch+1, t_loss=train_loss,
                                           v_loss=val_loss_1, v_loss2=val_loss_2, v_loss3=val_loss_3,
                                      t_acc=train_acc, v_acc=val_acc_1, v_acc2=val_acc_2, v_acc3=val_acc_3))  
                      
            elif number == 4:
                val_loss_1, val_acc_1 = eval_epoch(model, val_loader_1, criterion)
                val_loss_2, val_acc_2 = eval_epoch(model, val_loader_2, criterion)
                val_loss_3, val_acc_3 = eval_epoch(model, val_loader_3, criterion)
                val_loss_4, val_acc_4 = eval_epoch(model, val_loader_4, criterion)
                history.append((train_loss, train_acc, val_loss_1, val_acc_1, val_loss_2, val_acc_2, val_loss_3, val_acc_3, val_loss_4, val_acc_4))
                print('accuracy on 1:', val_acc_1)
                print('accuracy on 2:', val_acc_2)
                print('accuracy on 3:', val_acc_3)
                print('accuracy on 4:', val_acc_4)
                val_acc = (val_acc_1 + val_acc_2 + val_acc_3 + val_acc_4)/4
                prev_val = (val_acc_1 + val_acc_2 + val_acc_3)/3

                print('prev_val_acc:', prev_val)
                print('total_accuracy:', val_acc)
                
                if val_acc > best_acc and prev_val - val_acc_4 < 0.02:
                    best_acc = val_acc
                    best_model_weights = model.state_dict()  
                tqdm.write("\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
                            val_loss_1 {v_loss:0.4f} val_loss_2 {v_loss2:0.4f} val_loss_3 {v_loss3:0.4f} val_loss_4 {v_loss4:0.4f} \
                            train_acc {t_acc:0.4f} val_acc_1 {v_acc:0.4f} val_acc_2 {v_acc2:0.4f} val_acc_3 {v_acc3:0.4f} val_acc_4 {v_acc4:0.4f}\
                            ".format(ep=epoch+1, t_loss=train_loss,
                                           v_loss=val_loss_1, v_loss2=val_loss_2, v_loss3=val_loss_3, v_loss4=val_loss_4,
                                      t_acc=train_acc, v_acc=val_acc_1, v_acc2=val_acc_2, v_acc3=val_acc_3, v_acc4=val_acc_4))  
              
            elif number == 5:
                val_loss_1, val_acc_1 = eval_epoch(model, val_loader_1, criterion)
                val_loss_2, val_acc_2 = eval_epoch(model, val_loader_2, criterion)
                val_loss_3, val_acc_3 = eval_epoch(model, val_loader_3, criterion)
                val_loss_4, val_acc_4 = eval_epoch(model, val_loader_4, criterion)
                val_loss_5, val_acc_5 = eval_epoch(model, val_loader_5, criterion)

                history.append((train_loss, train_acc, val_loss_1, val_acc_1, val_loss_2, val_acc_2, val_loss_3, val_acc_3, val_loss_4, val_acc_4, val_loss_5, val_acc_5))
                print('accuracy on 1:', val_acc_1)
                print('accuracy on 2:', val_acc_2)
                print('accuracy on 3:', val_acc_3)
                print('accuracy on 4:', val_acc_4)
                print('accuracy on 5:', val_acc_5)
                prev_val = (val_acc_1 + val_acc_2 + val_acc_3 + val_acc_4)/4
                val_acc = (val_acc_1 + val_acc_2 + val_acc_3 + val_acc_4 + val_acc_5) / 5
                print('prev_val_acc:', prev_val)
                print('total_accuracy:', val_acc)
                
                if val_acc > best_acc and prev_val - val_acc_5 < 0.02:
                    best_acc = val_acc
                    best_model_weights = model.state_dict()  
                
                tqdm.write("\nEpoch {ep:03d} train_loss: {t_loss:0.4f} \
                            val_loss_1 {v_loss:0.4f} val_loss_2 {v_loss2:0.4f} val_loss_3 {v_loss3:0.4f} val_loss_4 {v_loss4:0.4f} val_loss_5 {v_loss5:0.4f}\
                            train_acc {t_acc:0.4f} val_acc_1 {v_acc:0.4f} val_acc_2 {v_acc2:0.4f} val_acc_3 {v_acc3:0.4f} val_acc_4 {v_acc4:0.4f} val_acc_5 {v_acc5:0.4f}\
                            ".format(ep=epoch+1, t_loss=train_loss,
                                           v_loss=val_loss_1, v_loss2=val_loss_2, v_loss3=val_loss_3, v_loss4=val_loss_4, v_loss5=val_loss_5,
                                      t_acc=train_acc, v_acc=val_acc_1, v_acc2=val_acc_2, v_acc3=val_acc_3, v_acc4=val_acc_4, v_acc5=val_acc_5))  
            #добавим сохранение лучшей модели

    print('Best total accuracy:', best_acc)
    model.load_state_dict(best_model_weights) #загрузим лучшую модель
    return history   

In [None]:
%%time
# Stage 1: train efficient_net on the first training split with `train`.
# NOTE(review): the positional args are presumably epochs=20, number=1, batch_size=64
# (the later new_ewc_train calls pass these by keyword) — confirm against train's signature.
history = train(train_dataset_1, efficient_net, 20, 1, 64)

In [None]:
def get_part_files(tr_files, part=5, rng=None):
    """
    Build a replay subset of ``tr_files``: a random 1/``part`` share of the
    files, each repeated ``part`` times, in shuffled order.

    In this notebook the result is concatenated with a later task's files so
    that training on the new task also sees samples of the earlier one.

    Args:
        tr_files: sequence of file entries (e.g. paths) to sample from.
        part: the subsample is ``len(tr_files) // part`` distinct files, and
            each of them is repeated ``part`` times. Defaults to 5, the value
            that was previously hard-coded.
        rng: optional NumPy ``Generator`` / ``RandomState`` for reproducible
            sampling. ``None`` uses the global ``np.random`` state, exactly as
            before.

    Returns:
        list with ``(len(tr_files) // part) * part`` entries. Empty when
        ``tr_files`` has fewer than ``part`` elements.

    Raises:
        ValueError: if ``part`` is not a positive integer.
    """
    if isinstance(part, bool) or not isinstance(part, (int, np.integer)) or part < 1:
        raise ValueError(f"part must be a positive integer, got {part!r}")
    # The np.random module and Generator/RandomState all expose .permutation
    rng = np.random if rng is None else rng

    files = np.array(tr_files)
    n_keep = len(files) // part
    # Pick n_keep distinct files uniformly at random
    subset = files[rng.permutation(len(files))[:n_keep]]
    # Oversample the kept files so the replay set is about as long as the input
    repeated = np.repeat(subset, part)
    # np.repeat places the copies next to each other, so shuffle them apart
    return list(repeated[rng.permutation(len(repeated))])

In [None]:
# Replay subset for task 1: a random 1/5 of train_files_1, each file repeated 5 times.
new_tr_files_1 = get_part_files(train_files_1)

In [None]:
# Sanity check: replay subset size, (len(train_files_1) // 5) * 5.
len(new_tr_files_1)

In [None]:
# Task-2 training set: all task-2 files plus the task-1 replay subset.
train_dataset_clf_2 = ImageNetteDataset(train_files_2 + new_tr_files_1, mode='train', part=1)

In [None]:
# Inspect GPU memory usage before stage-2 training.
!nvidia-smi

In [None]:
# Size of the task-2 training set (task-2 files + task-1 replay subset).
len(train_dataset_clf_2)

In [None]:
%%time
# Stage 2: train on task 2 (+ task-1 replay) with new_ewc_train; train_dataset_1 is
# passed as the previous task's data (presumably used for the EWC penalty — confirm).
# NOTE(review): importance=4e12 is ~800x larger than the 4e9–5e9 used in later
# stages — confirm this is intended.
# NOTE(review): this rebinds `history`, so the stage-1 history is no longer reachable.
history = new_ewc_train(train_dataset_clf_2, train_dataset_1, efficient_net, epochs=10, number=2, batch_size=64, importance=4e12)

In [None]:
# Inspect GPU memory usage after stage-2 training.
!nvidia-smi

In [None]:
# Replay subset for task 2: a random 1/5 of train_files_2, each file repeated 5 times.
new_tr_files_2 = get_part_files(train_files_2)

In [None]:
%%time
# Task-3 training set: all task-3 files plus the replay subsets of tasks 2 and 1.
train_dataset_clf_3 = ImageNetteDataset(train_files_3 + new_tr_files_2 + new_tr_files_1, mode='train', part=1)

In [None]:
# Inspect GPU memory usage before stage-3 training.
!nvidia-smi

In [None]:
%%time
# Stage 3: train on task 3 (+ replay of tasks 1-2); the previous stage's mixed
# dataset train_dataset_clf_2 is passed as the old-task data.
# NOTE(review): rebinds `history`, so the stage-2 history is no longer reachable.
history = new_ewc_train(train_dataset_clf_3, train_dataset_clf_2, efficient_net, epochs=15, number=3, batch_size=64, importance=5e9) # TODO: add per-class accuracy

In [None]:
# Collect unreferenced Python objects, then release cached, unused CUDA memory
# held by PyTorch's allocator so that nvidia-smi reflects the freed memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Inspect GPU memory usage after freeing the cache.
!nvidia-smi

In [None]:
# Same cleanup as the two cells above, combined into one cell: collect garbage,
# release cached CUDA memory, then show GPU memory usage.
gc.collect()
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
# Replay subset for task 3: a random 1/5 of train_files_3, each file repeated 5 times.
new_tr_files_3 = get_part_files(train_files_3)

In [None]:
%%time
# Task-4 training set: all task-4 files plus the replay subsets of tasks 3, 2 and 1.
train_dataset_clf_4 = ImageNetteDataset(train_files_4 + new_tr_files_3 + new_tr_files_2 + new_tr_files_1, mode='train', part=1)

In [None]:
%%time
# Stage 4: train on task 4 (+ replay of tasks 1-3); the previous stage's mixed
# dataset train_dataset_clf_3 is passed as the old-task data.
# NOTE(review): rebinds `history`, so the stage-3 history is no longer reachable.
history = new_ewc_train(train_dataset_clf_4, train_dataset_clf_3, efficient_net, epochs=10, number=4, batch_size=64, importance=5e9) # TODO: add per-class accuracy

In [None]:
# Collect garbage, release cached CUDA memory, then show GPU memory usage.
gc.collect()
torch.cuda.empty_cache()
!nvidia-smi

In [None]:
# Replay subset for task 4: a random 1/5 of train_files_4, each file repeated 5 times.
new_tr_files_4 = get_part_files(train_files_4)

In [None]:
%%time
# Task-5 training set: all task-5 files plus the replay subsets of tasks 4, 3, 2 and 1.
train_dataset_clf_5 = ImageNetteDataset(train_files_5 + new_tr_files_4 + new_tr_files_3 + new_tr_files_2 + new_tr_files_1, mode='train', part=1)

In [None]:
%%time
# Stage 5: train on task 5 (+ replay of tasks 1-4); the previous stage's mixed
# dataset train_dataset_clf_4 is passed as the old-task data.
# NOTE(review): rebinds `history`, so the stage-4 history is no longer reachable.
history = new_ewc_train(train_dataset_clf_5, train_dataset_clf_4, efficient_net, epochs=15, number=5, batch_size=64, importance=4e9) # TODO: add per-class accuracy