<a href="https://colab.research.google.com/github/arturovallemacias/diffusion_models/blob/main/main_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:

import git
import os

# Clonar el repositorio si aún no lo has hecho
if not os.path.exists("/content/diffusion_models"):
    git.Git("/content/").clone("https://github.com/arturovallemacias/diffusion_models.git")

Collecting gitpython
  Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gitdb<5,>=4.0.1 (from gitpython)
  Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython)
  Downloading smmap-5.0.1-py3-none-any.whl (24 kB)
Installing collected packages: smmap, gitdb, gitpython
Successfully installed gitdb-4.0.11 gitpython-3.1.40 smmap-5.0.1
Cloning into 'diffusion_models'...
remote: Enumerating objects: 151, done.[K
remote: Counting objects: 100% (85/85), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 151 (delta 55), reused 25 (delta 25), pack-reused 66[K
Receiving objects: 100% (151/151), 5.97 MiB | 18.09 MiB/s, done.
Resolving deltas: 100% (78

In [5]:
%cd /content/diffusion_models

/content/diffusion_models


In [6]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from unet import UNet

In [20]:
from unet_utils_aladin import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,

)


LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_availabe() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMG_DIR = "/content/carvana_dataset/train"
TRAIN_MASK_DIR = "/content/carvana_dataset/train_masks"
VAL_IMG_DIR = "/content/carvana_dataset/test"
VAL_MASK_DIR = "/content/carvana_dataset/out_mask"

ImportError: ignored

In [10]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):

        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

def main():
    train_transform = A.Compose(
        [
          A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
          A.Rotate(limit=35, p=1.0),
          A.HorizontalFlip(p=0.5),
          A.VerticalFlip(p=0.1),
          A.Normalize(
              mean=[0.0, 0.0, 0.0],
              std=[1.0,1.0,1.0],
              max_pixel_value=255.0,
          ),
          ToTensorV2(),
         ],

    )

    val_transforms = A.Compose(
        [

            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0,0.0,0.0],
                std = [1.0,1.0,1.0],
                max_pixel_value=255.0,
            )

        ]
    )

    model = UNet(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameteres(), lr= LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)

    check_accuracy(val_loader, model,device=DEVICE)
    scaler= torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }

        save_checkpoint(checkpoint)
        check_accuracy(val_loader, model, device=DEVICE)
        save_predictions_as_imgs(

            val_loader, model, folder="saved_images/", device=DEVICE
        )

In [21]:
%%writefile unet_dataset_aladin.py


import os
from PIL import Image
from torch.utils.data import Dataset
import  numpy as np

class CarvanaDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.masks_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index].replace(".jpg", "_mask.gif"))
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)
        mask[mask==255.0] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask








Overwriting unet_dataset_aladin.py


In [22]:
!pip install gitpython

from git import Repo

# Directorio del repositorio clonado
repo_dir = '/content/diffusion_models'  # Reemplaza con tu propio directorio

# Inicializa el repositorio
repo = Repo(repo_dir)

# Añade todos los cambios al commit
repo.git.add('unet_dataset_aladin.py')

# Hacer un commit con un mensaje descriptivo
commit_message = "Agregando el archivo archivo.py"  # Reemplaza con tu propio mensaje de commit
repo.index.commit(commit_message)

# Empujar los cambios al repositorio remoto
origin = repo.remote(name='origin')
origin.push()




[<git.remote.PushInfo at 0x79d9264a0ef0>]

In [10]:
!git add unet_utils_aladin.py  # Reemplaza 'mi_archivo.py' con el nombre real de tu archivo
!git commit -m "Agregando archivo .py"
!git push origin main  # O la rama principal correspondiente

[main 6bdb986] Agregando archivo .py
 1 file changed, 1 insertion(+), 2 deletions(-)
remote: Invalid username or password.
fatal: Authentication failed for 'https://github.com/arturovallemacias/diffusion_models.git/'


In [11]:
from google.colab import files
uploaded = files.upload()

Saving kaggle.json to kaggle.json


In [12]:
!pip install kaggle

# Copiar el archivo de configuración de Kaggle a la ubicación correcta
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Descargar el conjunto de datos Carvana desde Kaggle
!kaggle competitions download -c carvana-image-masking-challenge

Downloading carvana-image-masking-challenge.zip to /content/diffusion_models
100% 24.4G/24.4G [04:17<00:00, 145MB/s]
100% 24.4G/24.4G [04:17<00:00, 102MB/s]


In [14]:
import zipfile
import os

# Ruta del archivo zip
zip_file_path = '/content/diffusion_models/carvana-image-masking-challenge.zip'

# Directorio de extracción
extract_dir = '/content/'

# Extracción del archivo zip
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)

# Listar los archivos extraídos
extracted_files = os.listdir(extract_dir)
print("Archivos extraídos:", extracted_files)


Archivos extraídos: ['.config', 'train_masks.csv.zip', 'metadata.csv.zip', 'test_hq.zip', '29bb3ece3180_11.jpg', 'test.zip', 'train.zip', 'diffusion_models', 'train_hq.zip', 'sample_submission.csv.zip', 'train_masks.zip', 'sample_data']


In [None]:
import zipfile
with zipfile.ZipFile("carvana-image-masking-challenge.zip", 'r') as zip_ref:
    zip_ref.extractall("carvana_dataset")

In [16]:
import zipfile
import os
zip_files = ["/content/train.zip","/content/test.zip","/content/train_masks.zip"]

extracted_folders = ["/content/carvana_dataset/", "/content/carvana_dataset/","/content/carvana_dataset/"]

for i in range(len(zip_files)):
    with zipfile.ZipFile(zip_files[i], "r") as zip_ref:
        zip_ref.extractall(extracted_folders[i])

for i in range(len(extracted_folders)):
    print(f"Archivos extraidos de {zip_files[i]}")
    for root, dirs, files in os.walk(extracted_folders[i]):
        for file in files:
            print(os.path.join(root, file))

    print("\n")




[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
/content/carvana_dataset/test/27be65b8a354_15.jpg
/content/carvana_dataset/test/6a548b4d774e_04.jpg
/content/carvana_dataset/test/1d245879bf81_13.jpg
/content/carvana_dataset/test/b8b0b070a16b_10.jpg
/content/carvana_dataset/test/4d009085c3bb_16.jpg
/content/carvana_dataset/test/e7024c18c6f7_06.jpg
/content/carvana_dataset/test/6e4a7474e828_06.jpg
/content/carvana_dataset/test/86e4913cb334_02.jpg
/content/carvana_dataset/test/7cc00945cdac_05.jpg
/content/carvana_dataset/test/89326e85d3b8_08.jpg
/content/carvana_dataset/test/7a2ff2f3b083_08.jpg
/content/carvana_dataset/test/dce98cb34ebc_04.jpg
/content/carvana_dataset/test/8a81988f6f79_08.jpg
/content/carvana_dataset/test/f9caed74fae4_10.jpg
/content/carvana_dataset/test/3f2d4f4ed0ed_03.jpg
/content/carvana_dataset/test/97bfac3f4250_09.jpg
/content/carvana_dataset/test/aa1f2e09e3be_13.jpg
/content/carvana_dataset/test/43ceca19eadc_08.jpg
/content/carvana_dataset/

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
/content/carvana_dataset/test/aeb5f56d203a_03.jpg
/content/carvana_dataset/test/57e2d0fbed79_10.jpg
/content/carvana_dataset/test/a844c716d3d4_04.jpg
/content/carvana_dataset/test/2acdb52e0b2f_14.jpg
/content/carvana_dataset/test/d9f5b7504176_07.jpg
/content/carvana_dataset/test/10438223cb80_07.jpg
/content/carvana_dataset/test/d0ad51b77f53_08.jpg
/content/carvana_dataset/test/8c9416066607_15.jpg
/content/carvana_dataset/test/d57aad6cfeae_05.jpg
/content/carvana_dataset/test/6eff4f676784_07.jpg
/content/carvana_dataset/test/6dcb61f4776a_04.jpg
/content/carvana_dataset/test/68c696971bf4_03.jpg
/content/carvana_dataset/test/68c696971bf4_01.jpg
/content/carvana_dataset/test/b809440df171_08.jpg
/content/carvana_dataset/test/002b362dee46_03.jpg
/content/carvana_dataset/test/6ef3cad2f796_12.jpg
/content/carvana_dataset/test/96e7aca3d1ca_08.jpg
/content/carvana_dataset/test/dc6edf33a6ce_06.jpg
/content/carvana_dataset/

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
/content/carvana_dataset/test/24603e4c5c39_02.jpg
/content/carvana_dataset/test/eecb0765b5b1_16.jpg
/content/carvana_dataset/test/f10ef6fbf62d_13.jpg
/content/carvana_dataset/test/0bdd2e625f8a_05.jpg
/content/carvana_dataset/test/f66e53cc1e22_13.jpg
/content/carvana_dataset/test/4df485585cde_07.jpg
/content/carvana_dataset/test/77ccb270542d_12.jpg
/content/carvana_dataset/test/1ca7da87c8e2_10.jpg
/content/carvana_dataset/test/3647ec69fec2_03.jpg
/content/carvana_dataset/test/4fb9583978bd_02.jpg
/content/carvana_dataset/test/3d5f02b3db32_03.jpg
/content/carvana_dataset/test/a9c154d342c1_02.jpg
/content/carvana_dataset/test/8486948182cb_02.jpg
/content/carvana_dataset/test/3be98f90f361_14.jpg
/content/carvana_dataset/test/ee7c410405e9_16.jpg
/content/carvana_dataset/test/93025ea4ce59_13.jpg
/content/carvana_dataset/test/b64a3bef2736_12.jpg
/content/carvana_dataset/test/fcd988291d6e_05.jpg
/content/carvana_dataset/