In [2]:
from utils import make_dataloader, MainModel, train_model

import mlflow
import dagshub
import torch
from torchvision.models import inception_v3
from torcheval.metrics import FrechetInceptionDistance as FID

# Select the compute device: the first CUDA GPU when available, else the CPU.
# (The former `global device` line was dropped: `global` at notebook/module
# scope is a no-op, since cell-level names are already module globals.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Load dataset - Places365

In [7]:
# Dataset root handed to make_dataloader for every split.
# NOTE(review): an empty string presumably means "relative to the current
# working directory" — confirm against utils.make_dataloader.
path = ''

# Training split; len() of the loader is presumably its batch count
# (assumes make_dataloader returns a torch DataLoader — TODO confirm).
train_loader = make_dataloader(split='train_samples', dir_path=path)
n_train_batches = len(train_loader)
print(n_train_batches)

685


In [ ]:
# Held-out split used later for FID evaluation, loaded one sample per batch.
test_loader = make_dataloader(
    split='test_samples',
    batch_size=1,
    dir_path=path,
)
n_test_batches = len(test_loader)
print(n_test_batches)

In [ ]:
# Connect MLflow to the DagsHub-hosted tracking server.
# NOTE(review): runs are logged to the 'churn-app' repository, which does not
# look related to this image-colorization GAN — confirm this is intended.
tracking_uri = 'https://dagshub.com/xagallegos/churn-app.mlflow'
dagshub.init('churn-app', 'xagallegos', mlflow=True)
mlflow.set_tracking_uri(tracking_uri)

In [ ]:
# Run configuration: MLflow run/model name and number of training epochs.
name_model = "GAN"
epochs = 5

# Hyperparameters forwarded as keyword arguments to MainModel and logged to
# MLflow. Meanings below are inferred from the names — confirm in utils.MainModel.
parameters = dict(
    lr_G=2e-4,      # presumably the generator learning rate
    lr_D=2e-4,      # presumably the discriminator learning rate
    beta1=0.5,      # presumably optimizer beta1
    beta2=0.999,    # presumably optimizer beta2
    lambda_L1=100.,  # presumably the weight of an L1 loss term
)
                  
with mlflow.start_run(run_name=name_model):

    # Log the run configuration
    mlflow.log_param("model", name_model)
    mlflow.log_param("epochs", epochs)
    # Log the model hyperparameters
    mlflow.log_params(parameters)

    # Build the GAN from the hyperparameters and train it on the training split
    model = MainModel(**parameters)
    train_model(model, train_loader, epochs=epochs, checkpoints_dir='')

    # FID metric.
    # torcheval's FrechetInceptionDistance runs its own InceptionV3 feature
    # extractor inside `update`, so it must be fed *images* (4D, 3 channels),
    # not precomputed features, and the real/fake flag is the keyword `is_real`.
    # NOTE(review): `update` also requires float values in [0, 1]; this assumes
    # `pred` and `model.color` are RGB batches in that range — TODO confirm
    # against utils.MainModel.predict.
    metric = FID(device=device)

    # Evaluation only: don't build autograd graphs while predicting.
    with torch.no_grad():
        for data in test_loader:
            pred = model.predict(data)
            # `model.color` is read after predict(), as before — presumably
            # predict() stores the ground-truth images for this batch there.
            metric.update(model.color, is_real=True)
            metric.update(pred, is_real=False)

    fid_score = metric.compute()

    # Log metrics; MLflow expects a Python number, not a 0-d tensor
    mlflow.log_metric("FID", fid_score.item())

    # Log the trained model
    mlflow.pytorch.log_model(model, name_model)
    