In [1]:
# %% [markdown]
# # Fine-Tuning Model

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer, callbacks
from utils.utils import MoleculeDataModule
from utils.train import MoleculeModel

# Load the model from the saved checkpoint. The freeze-option loop further
# down reloads its own copy for every run, so this instance is not touched
# there; it is the one passed to `trainer.fit` in the combined-dataset cell.
model = MoleculeModel.load_from_checkpoint("final_model.ckpt")

def freeze_layers(model, num_layers_to_freeze):
    """Disable gradients for the first ``num_layers_to_freeze`` children of ``model``.

    Children are taken in the order yielded by ``model.children()``. Every
    parameter of each selected child gets ``requires_grad = False``. The
    module is modified in place and nothing is returned.

    Args:
        model: Module whose leading direct children should be frozen.
        num_layers_to_freeze: Number of children to freeze, counted from the
            first one. Zero or a negative value freezes nothing. A value
            larger than the number of children freezes all of them instead
            of raising ``IndexError``.
    """
    if num_layers_to_freeze <= 0:
        return

    # Slicing clamps to the children that actually exist, so an over-large
    # count simply freezes every child.
    for layer in list(model.children())[:num_layers_to_freeze]:
        for param in layer.parameters():
            param.requires_grad = False

# Load the higher-quality dataset used for fine-tuning.
# Alternative dataset file: '../data/QM_cool_no_conf.pt'
quality_dataset = torch.load('../data/QM_cool.pt')

batch_size = 1024
num_workers = 8
quality_data_module = MoleculeDataModule(
    quality_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Layer-freezing variants: each value is passed to `freeze_layers` as the
# number of leading backbone children to freeze.
freeze_options = [0, 2, 4, 6]

# Shared callbacks and logger.
# NOTE(review): `checkpoint_callback` is not passed to any Trainer in the
# visible cells (they all set enable_checkpointing=False) - confirm whether
# it is still needed.
checkpoint_callback = callbacks.ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    verbose=True,
)
early_stop_callback = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    verbose=True,
    mode='min',
)
timer = callbacks.Timer()
logger = pl.loggers.TensorBoardLogger('tb_logs', name='KAN_More_epochs')

# Fine-tune one model per freeze option and report the training time of each.
for freeze_option in freeze_options:
    # Reload the checkpoint so every variant starts from the same weights.
    fine_tune_model = MoleculeModel.load_from_checkpoint("final_model.ckpt")
    freeze_layers(fine_tune_model.model_backbone, freeze_option)

    # Callbacks are stateful: EarlyStopping keeps its best score and wait
    # counter after a fit. Reusing one instance across runs would let a
    # previous run's state stop the next run prematurely, so each run gets
    # fresh instances.
    run_early_stop = callbacks.EarlyStopping(
        monitor='val_loss', patience=10, verbose=True, mode='min'
    )
    run_timer = callbacks.Timer()

    # A logger instance pins its version directory on first use. A separate,
    # labelled logger per run keeps the curves of different freeze options
    # from being written on top of each other.
    run_logger = pl.loggers.TensorBoardLogger(
        'tb_logs', name=f'KAN_More_epochs_freeze_{freeze_option}'
    )

    trainer = Trainer(
        max_epochs=20,
        enable_checkpointing=False,
        callbacks=[run_early_stop, run_timer],
        enable_progress_bar=False,
        logger=run_logger,
        accelerator='gpu',
        devices=1
    )

    trainer.fit(fine_tune_model, quality_data_module)

    # Elapsed training time of this run, formatted as H:MM:SS.
    seconds = run_timer.time_elapsed()
    h, m, s = int(seconds // 3600), int((seconds % 3600) // 60), int(seconds % 60)

    print(f"Время обучения: {h}:{m:02d}:{s:02d}")



cuda True
NVIDIA GeForce RTX 3090


In [None]:
from torch.utils.data import DataLoader, ConcatDataset, Subset
import random

# Build an oversampled dataset: a Subset that repeats every index of the
# quality dataset 10 times, concatenated with the original dataset, so each
# sample appears 11 times in total.
quality_indices = list(range(len(quality_dataset)))
extended_quality_dataset = Subset(quality_dataset, quality_indices * 10)

combined_dataset = ConcatDataset([quality_dataset, extended_quality_dataset])
# NOTE(review): if MoleculeDataModule splits its dataset randomly into
# train/val, copies of the same sample can end up on both sides and make
# val_loss optimistic - confirm how the split is done.
combined_data_module = MoleculeDataModule(combined_dataset, batch_size=batch_size, num_workers=num_workers)

# Use fresh callbacks and a fresh logger for this run. The instances created
# in the first cell may already have been attached to earlier fits, and
# EarlyStopping keeps its best score / wait counter between fits, which
# could stop this run prematurely.
combined_early_stop = callbacks.EarlyStopping(
    monitor='val_loss', patience=10, verbose=True, mode='min'
)
combined_timer = callbacks.Timer()
combined_logger = pl.loggers.TensorBoardLogger('tb_logs', name='KAN_More_epochs_combined')

trainer = Trainer(
    max_epochs=20,
    enable_checkpointing=False,
    callbacks=[combined_early_stop, combined_timer],
    enable_progress_bar=False,
    logger=combined_logger,
    accelerator='gpu',
    devices=1
)

# Trains the `model` instance loaded from the checkpoint in the first cell.
trainer.fit(model, combined_data_module)

# Elapsed training time of this run, formatted as H:MM:SS.
seconds = combined_timer.time_elapsed()
h, m, s = int(seconds // 3600), int((seconds % 3600) // 60), int(seconds % 60)
print(f"Время обучения: {h}:{m:02d}:{s:02d}")
