In [1]:
from tensorboardX import SummaryWriter
from validate import validate
from networks.trainer import Trainer
from torch.utils.data import DataLoader
import numpy as np
import os
import time
import random
import torch

#from base_miner.util import Logger
from util.data import load_datasets, create_real_fake_datasets
from options import TrainOptions


def seed_torch(seed=1029):
    """Seed Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Args:
        seed: Integer seed handed to every generator (default ``1029``).

    Side effects:
        * Sets ``os.environ['PYTHONHASHSEED']`` to ``str(seed)``.
          NOTE(review): Python reads this variable at interpreter start-up,
          so it presumably only matters for processes launched afterwards.
        * Seeds ``random``, ``numpy.random``, the torch CPU generator and the
          torch CUDA generators (current device and all devices).
        * Turns cuDNN benchmarking off, requests deterministic cuDNN
          algorithms, and finally disables cuDNN altogether.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every generator receives the same seed value.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False



  from .autonotebook import tqdm as notebook_tqdm


{'Gustavosta/MagicPrompt-Stable-Diffusion': {'model': 'Gustavosta/MagicPrompt-Stable-Diffusion', 'tokenizer': 'gpt2', 'device': -1}}
{'stabilityai/stable-diffusion-xl-base-1.0': {'use_safetensors': True, 'variant': 'fp16'}, 'SG161222/RealVisXL_V4.0': {'use_safetensors': True, 'variant': 'fp16'}, 'Corcelio/mobius': {'use_safetensors': True}}


In [2]:
# Parse the training options (this prints the option table) and fix all RNG seeds.
opt = TrainOptions().parse()
seed_torch(100)

#Logger(os.path.join(opt.checkpoints_dir, opt.name, 'log.log'))

# TensorBoard writers live under <checkpoints_dir>/<name>/{train,val}.
run_dir = os.path.join(opt.checkpoints_dir, opt.name)
train_writer = SummaryWriter(os.path.join(run_dir, "train"))
val_writer = SummaryWriter(os.path.join(run_dir, "val"))

# Load the real and fake source datasets, then build the train/val/test splits.
real_datasets, fake_datasets = load_datasets()
train_dataset, val_dataset, test_dataset = create_real_fake_datasets(real_datasets, fake_datasets)

# Settings shared by all three loaders; each batch is handed over as a plain
# tuple of samples rather than being collated into tensors.
# NOTE(review): batch_size is fixed at 32 here although the parsed options
# report batch_size: 64 -- confirm which value is intended.
loader_kwargs = dict(batch_size=32, num_workers=0, collate_fn=lambda d: tuple(d))

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)




----------------- Options ---------------
                     arch: res50                         
               batch_size: 64                            
                    beta1: 0.9                           
                blur_prob: 0                             
                 blur_sig: 0.5                           
          checkpoints_dir: ./checkpoints                 
                class_bal: False                         
                  classes:                               
           continue_train: False                         
                 cropSize: 224                           
                 data_aug: False                         
                 dataroot: ./dataset/                    
                delr_freq: 20                            
          earlystop_epoch: 15                            
                    epoch: latest                        
              epoch_count: 1                             
                  gpu_ids: 0  

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


done, len=8000
Loading bitmind/RealVisXL_V4.0_images[validation] ... done, len=294
Loading poloclub/diffusiondb[validation] ... done, len=1000
Loading bitmind/RealVisXL_V4.0_images[test] ... done, len=294
Loading poloclub/diffusiondb[test] ... done, len=1000
Loading dalle-mini/open-images[train] ... done, len=9011219
Loading merkol/ffhq-256[train] ... done, len=56000
Loading jlbaker361/flickr_humans_20k[train] ... done, len=16000
Loading saitsharipov/CelebA-HQ[train] ... done, len=162079
Loading dalle-mini/open-images[validation] ... done, len=41620
Loading merkol/ffhq-256[validation] ... done, len=7000
Loading jlbaker361/flickr_humans_20k[validation] ... done, len=2000
Loading saitsharipov/CelebA-HQ[validation] ... done, len=20260
Loading dalle-mini/open-images[test] ... done, len=125436
Loading merkol/ffhq-256[test] ... done, len=7000
Loading jlbaker361/flickr_humans_20k[test] ... done, len=2000
Loading saitsharipov/CelebA-HQ[test] ... done, len=20260


In [3]:
# Sizes of the (train, val, test) splits, shown as the cell result.
tuple(len(split) for split in (train_dataset, val_dataset, test_dataset))

(2352, 294, 294)

In [4]:
# Scratch arithmetic: 300 out of 2900 is roughly 10.3%.
# NOTE(review): presumably a rough held-out fraction check -- confirm, or
# remove this cell before sharing the notebook.
300 / 2900

0.10344827586206896

In [5]:

# Train for up to `opt.niter` epochs, validating after every epoch, keeping the
# checkpoint with the best validation accuracy, and stopping early once the
# accuracy has failed to improve for `early_stopping_epochs` epochs in a row.
model = Trainer(opt)
# Patience for early stopping.
# NOTE(review): the parsed options also carry `earlystop_epoch` (printed as 15);
# confirm whether this hard-coded 10 is intended.
early_stopping_epochs = 10
best_val_acc = 0
n_epoch_since_improvement = 0
model.train()

print(f'cwd: {os.getcwd()}')
for epoch in range(opt.niter):
    for i, data in enumerate(train_loader):
        model.total_steps += 1

        # One optimisation step on the current batch.
        model.set_input(data)
        model.optimize_parameters()

        ts = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
        print(ts, "Train loss: {} at step: {} lr {}".format(model.loss, model.total_steps, model.lr))

        # Log the training loss to TensorBoard every `opt.loss_freq` steps.
        if model.total_steps % opt.loss_freq == 0:
            train_writer.add_scalar('loss', model.loss, model.total_steps)

    # Every `opt.delr_freq` epochs (never at epoch 0) let the trainer adjust
    # its learning rate.
    if epoch % opt.delr_freq == 0 and epoch != 0:
        ts = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
        print(ts, 'changing lr at the end of epoch %d, iters %d' % (epoch, model.total_steps))
        model.adjust_learning_rate()

    # Validation
    model.eval()
    acc, ap = validate(model.model, val_loader)[:2]
    val_writer.add_scalar('accuracy', acc, model.total_steps)
    val_writer.add_scalar('ap', ap, model.total_steps)

    print("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))
    if acc > best_val_acc:
        model.save_networks('best')
        best_val_acc = acc
        # Restart the patience window: without this reset the counter would
        # accumulate every non-improving epoch of the whole run and trigger
        # early stopping even while the model is still improving.
        n_epoch_since_improvement = 0
    else:
        n_epoch_since_improvement += 1
        if n_epoch_since_improvement >= early_stopping_epochs:
            break

    model.train()

# Final evaluation on the held-out test split using the weights from the last
# completed epoch (not the 'best' checkpoint), then save those weights.
model.eval()
acc, ap = validate(model.model, test_loader)[:2]
print("(Test) acc: {}; ap: {}".format(acc, ap))
model.save_networks('last')

cwd: /home/user/bitmind-subnet/base_miner
2024_06_11_05_57_13 Train loss: 0.6451554298400879 at step: 1 lr 0.0001
2024_06_11_05_57_14 Train loss: 0.652596652507782 at step: 2 lr 0.0001
2024_06_11_05_57_16 Train loss: 0.7342991828918457 at step: 3 lr 0.0001
2024_06_11_05_57_18 Train loss: 0.6682878732681274 at step: 4 lr 0.0001
2024_06_11_05_57_19 Train loss: 0.6569811105728149 at step: 5 lr 0.0001
2024_06_11_05_57_20 Train loss: 0.6456738114356995 at step: 6 lr 0.0001
2024_06_11_05_57_21 Train loss: 0.7409511804580688 at step: 7 lr 0.0001
2024_06_11_05_57_23 Train loss: 0.7689532041549683 at step: 8 lr 0.0001
2024_06_11_05_57_24 Train loss: 0.6740535497665405 at step: 9 lr 0.0001
2024_06_11_05_57_26 Train loss: 0.7736349105834961 at step: 10 lr 0.0001
2024_06_11_05_57_27 Train loss: 0.7192068099975586 at step: 11 lr 0.0001
2024_06_11_05_57_28 Train loss: 0.6234734058380127 at step: 12 lr 0.0001
2024_06_11_05_57_29 Train loss: 0.7114592790603638 at step: 13 lr 0.0001
2024_06_11_05_57_30