# Model training

In [None]:
import warnings
warnings.filterwarnings("ignore")
from model.train import *
import random

In [None]:
# Run configuration: data/checkpoint locations, compute device and RNG seeds.
data_path = './SVC/'
dataset = 'data/seqfish'
ckpt_folder = f'{data_path}checkpoints/'

# Prefer the second GPU (the original setup), but fall back to the first GPU or
# the CPU so "Restart Kernel -> Run All" still works on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Seed every RNG the notebook may draw from (torch, Python `random`, NumPy legacy global).
seed = 2025
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

**Required model inputs**

- `train_image`: input for the *gene-level subcellular localization embedding*

- `gene2vec_weight`: input for the *gene-level functional embedding*

- `train_cell_morphology`, `train_nuclear_morphology`: inputs for the *cell-level morphological embedding*

- `train_data_location`: input for the *cell-level positional embedding*

- `train_cell_cycle_label` (optional): input for the *cell-level identity embedding*

In [3]:
# Load the seqFISH training split. np.load on an .npz archive returns an NpzFile
# that holds the file open; using it as a context manager closes the archive once
# every array below has been read into memory.
with np.load(f"{data_path}{dataset}/train_seqfish.npz") as train_seqfish:
    train_image = train_seqfish["data_ori"]
    train_cell_morphology = train_seqfish["cell_morphology"]
    train_nuclear_morphology = train_seqfish["nuclear_morphology"]
    train_data_location = train_seqfish["location"]
    train_cell_cycle_label = train_seqfish["identity_label"]

train_dataset = SVC_Dataset(
    data_ori=train_image,
    location=train_data_location,
    cell_morphology_vec=train_cell_morphology,
    nuclear_morphology_vec=train_nuclear_morphology,
    identity_vec=train_cell_cycle_label,
)
print("number of training cells:", len(train_dataset),', number of genes:', train_image.shape[1])

# Median over cells of each cell's summed signal; handed to train_SVC later.
# NOTE(review): assumes train_image is 4-D with cells on axis 0 and genes on
# axis 1 (as the print above implies) - confirm.
cell_median_train = np.median(train_image.sum((1,2,3)))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

# gene2vec embedding weights for the functional embedding, cast to float32.
# (`read_dir` holds a file path, not a directory.)
read_dir = f'{data_path}{dataset}/gene2vec_weight_seqfish.npy'
gene2vec_weight = torch.from_numpy(np.load(read_dir)).float()

number of training cells: 157 , number of genes: 1000


In [4]:
# Instantiate SVC with the gene2vec weights loaded above; the identity-embedding
# width is taken from the second axis of the cell-cycle label array.
model = SVC(
    cell_identity_dim=train_cell_cycle_label.shape[1],
    gene2vec_weight=gene2vec_weight,
)
model = model.to(device)


In [None]:
# Fit the model for 200 epochs. train_SVC receives the checkpoint folder and name,
# so it presumably writes checkpoints there (its log reports a best-loss epoch).
# NOTE(review): the return value is stored as `epoch_losses`; assumed to be the
# per-epoch training losses - confirm against train_SVC.
epoch_losses = train_SVC(
    model=model,
    train_loader=train_loader,
    device=device,
    cell_median_train=cell_median_train,
    num_epochs=200,
    ckpt_dir=ckpt_folder,
    ckpt_name="SVC_seqfish",
)


Training:   0%|          | 0/200 [00:00<?, ?epoch/s]

Training: 100%|██████████| 200/200 [07:55<00:00,  2.38s/epoch, loss=38.6997]


Finished training at epoch 200 | best loss 38.1892 at epoch 164
