In [4]:
import numpy as np 
from PIL import Image 
import os 
import polars as pl
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import v2
from PIL import Image
import glob
import torch
from pathlib import Path

In [5]:
df = pl.read_csv('/home/rijkaa/leraa/kaggle_dataset.csv')
df = df.drop("filename")
df

Unnamed: 0,image_path,target
0,/kaggle/input/artifact-dataset/lsun/cat/cat/im...,0
1,/kaggle/input/artifact-dataset/coco/coco/coco2...,0
2,/kaggle/input/artifact-dataset/lsun/church/chu...,0
3,/kaggle/input/artifact-dataset/lsun/car/car/im...,0
4,/kaggle/input/artifact-dataset/stylegan2/churc...,1
...,...,...
249995,/kaggle/input/artifact-dataset/coco/coco/coco2...,0
249996,/kaggle/input/artifact-dataset/coco/coco/coco2...,0
249997,/kaggle/input/artifact-dataset/stylegan2/cat-p...,1
249998,/kaggle/input/artifact-dataset/stylegan2/cat-p...,1


In [7]:
class CustomDataset(Dataset):

    def __init__(self, data):
        self.DataFrame = data
        
        
    def __getitem__(self, idx):
        img = np.array(Image.open(self.DataFrame[idx, 0]).resize((224, 224)), dtype='uint8')
        label = self.DataFrame[idx, 1]
        
        T = v2.Compose(
            [#transforms.ToTensor(),
             v2.ToImage(),
             v2.ToDtype(torch.uint8, scale=True),
             v2.Resize((224, 224), antialias=True),
             v2.RandomHorizontalFlip(0.1),
             v2.RandomRotation(5), 
             v2.ToDtype(torch.float32, scale=True),
             v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
             ] 
        )

        img = T(img)
        
        return {'image': img,
                'label': torch.tensor(label)}
        
        
    def __len__(self):
        return len(self.DataFrame)

In [8]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(df, train_size=0.8, random_state=42)

In [9]:
# предобработанные датасеты

train = CustomDataset(train)
test = CustomDataset(test)

train_dataloader = DataLoader(train, batch_size=512)
test_dataloader = DataLoader(test, batch_size=512)

In [None]:
train[0]

In [11]:
resnet = torchvision.models.resnet18(weights= 'IMAGENET1K_V1')
resnet.fc = torch.nn.Linear(in_features= 512, out_features=2, bias=True)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 125MB/s] 


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import f1_score

def train_model(model, train_dataloader, test_dataloader, epochs, optimizer, loss_function, dataset):
    
    min_test_loss = np.inf
    min_train_loss = np.inf
    f1_max = 0
    f1_count = 0
    
    for epoch in tqdm(range(epochs)):
        model.train()
        train_loss = 0 
        test_loss = 0
        f1_train = 0
        f1_test = 0
        
        for data in tqdm(train_dataloader):
            img = data["image"].cuda()
            label = data["label"].cuda()
            
            logit = model(img)
            
            optimizer.zero_grad()
            loss = loss_function(logit, label)
            f1 = f1_score(logit.argmax(dim = 1).cpu().detach().numpy(), label.cpu().detach().numpy(), average = "binary")
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() / img.shape[0] 
            f1_train += f1.item()
        
        f1_train = f1_train / len(train_dataloader)
        
        if train_loss < min_train_loss:
            min_train_loss = train_loss
            min_train_loss_epoch = epoch 
            
        model.eval()
        with torch.no_grad():
            for data in tqdm(test_dataloader):
                img = data["image"].cuda()
                label = data["label"].cuda()
                
                logit = model(img)
                
                loss = loss_function(logit, label)
                f1 = f1_score(logit.argmax(dim = 1).cpu().detach().numpy(), label.cpu().detach().numpy(), average = "binary")
                
                test_loss += loss.item() / img.shape[0]
                f1_test += f1.item() 
        
        f1_test = f1_test / len(test_dataloader)
        
        
        if test_loss < min_test_loss:
            min_test_loss = test_loss
            min_test_loss_epoch = epoch 
            
        print(f"Epoch #{epoch + 1}, train loss = {train_loss}, test loss = {test_loss}, train f1 = {f1_train}, test f1 = {f1_test} ")
        
        
        if f1_max < f1_test:
            f1_max = f1_test 
            torch.save(model.state_dict(), f"best_model_{dataset}.pt")
            f1_max_epoch = epoch + 1 
            f1_count = 0
        # как определить что произошло переобучение
        elif f1_test < f1_max * 0.8:
            f1_count += 1
            if f1_count > 3:
                break
           
    print(f"total epochs = {epoch + 1}, best epoch = {f1_max_epoch}, best f1 = {f1_max}")        
    return f1_max, f1_max_epoch, min_test_loss_epoch, min_test_loss
        

In [14]:
resnet = torch.nn.DataParallel(resnet)
resnet.to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.001)
epochs = 25
dataset = 'artifact_250000'

In [15]:
train_model(resnet, train_dataloader, test_dataloader, epochs, optimizer, loss_function, dataset)

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #1, train loss = 0.40038146341685205, test loss = 0.12007781481259458, train f1 = 0.7126232826341238, test f1 = 0.5617351444949258 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #2, train loss = 0.3301615107106045, test loss = 0.09054764634874161, train f1 = 0.7825094888360071, test f1 = 0.7514009415527658 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #3, train loss = 0.28772949061822145, test loss = 0.12866268407309517, train f1 = 0.8191921990517436, test f1 = 0.6223576448491787 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #4, train loss = 0.25582611145218837, test loss = 0.1645471734560228, train f1 = 0.8450620116663836, test f1 = 0.5461497789406208 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #5, train loss = 0.23146030884818175, test loss = 0.08374297614431098, train f1 = 0.8629333467307987, test f1 = 0.8111537167274344 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #6, train loss = 0.21251669994089753, test loss = 0.09248818671663425, train f1 = 0.8758933022784064, test f1 = 0.7611283237671969 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #7, train loss = 0.19295361078111456, test loss = 0.09278275573437679, train f1 = 0.8901618647273054, test f1 = 0.8170516915529377 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #8, train loss = 0.17970282491878606, test loss = 0.08740601806147467, train f1 = 0.8983447565075036, test f1 = 0.8231869031061506 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #9, train loss = 0.16759505060035734, test loss = 0.08172875911917626, train f1 = 0.9063392014236048, test f1 = 0.8276118749421693 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #10, train loss = 0.15224846318596974, test loss = 0.09179014691506468, train f1 = 0.9157617305772433, test f1 = 0.8104355966249107 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #11, train loss = 0.14158777967677452, test loss = 0.08447217894718051, train f1 = 0.9230596767479581, test f1 = 0.8217599566408073 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #12, train loss = 0.1328904855181463, test loss = 0.07315789566685756, train f1 = 0.9282216718923384, test f1 = 0.8477078522553144 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #13, train loss = 0.12095250041165855, test loss = 0.1023024481754484, train f1 = 0.93604506076567, test f1 = 0.8234460195896263 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #14, train loss = 0.10988122297276277, test loss = 0.09660138589762417, train f1 = 0.9426906185759493, test f1 = 0.8419322862436707 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #15, train loss = 0.10483165525947698, test loss = 0.0780994071170599, train f1 = 0.9458254537174677, test f1 = 0.853248058568242 


  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

Epoch #16, train loss = 0.09585793176665902, test loss = 0.07909130855807148, train f1 = 0.9510169973258052, test f1 = 0.8465458267437976 


  0%|          | 0/391 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [32]:
torch.cuda.empty_cache()

In [31]:
gc.collect()

1487

In [33]:
!nvidia-smi

Sun Mar 31 20:58:00 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   61C    P0              29W /  70W |    791MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla T4                       Off | 00000000:00:05.0 Off |  