In [None]:
import torch

from gate.models.task_specific_models.semantic_segmentation.timm import (
    ModelAndTransform,
    build_gate_model,
    build_model,
)

In [None]:
# Build the segmentation model bundle. Later cells read its `.model` and
# `.transform` attributes.
model_and_transform = build_gate_model(num_classes=100, pretrained=True)

In [None]:
# Resolve the target device once: use the GPU when one is available and fall
# back to the CPU otherwise, so this cell does not crash on CUDA-less hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Random stand-in batch, allocated directly on the target device instead of
# being created on the CPU and copied afterwards.
image = torch.rand(2, 3, 512, 512, device=device)
# The exclusive upper bound (100) mirrors the num_classes=100 passed to
# build_gate_model; presumably the two must stay in sync -- confirm if changed.
labels = torch.randint(low=0, high=100, size=(2, 1, 256, 256), device=device)

model = model_and_transform.model.to(device)
transform = model_and_transform.transform

input_dict = transform({"image": image, "labels": labels})
# Move every transformed entry to the target device explicitly, in case the
# transform returns values that live elsewhere.
input_dict = {k: v.to(device) for k, v in input_dict.items()}

In [None]:
# Call the module instance rather than `.forward` directly: `nn.Module.__call__`
# runs registered forward hooks, which a direct `.forward` call silently skips.
output = model(input_dict)

In [None]:
# Sanity check: pull the loss out of the nested output dict and make sure a
# backward pass runs end to end.
image_outputs = output['image']['image']
loss = image_outputs['loss']
loss.backward()

In [None]:
import accelerate
import transformers  # no longer used in this cell; kept in case other cells rely on it
from tqdm.auto import tqdm

# Wrap the model with Accelerate using fp16 mixed precision.
accelerator = accelerate.Accelerator(mixed_precision='fp16')
model = accelerator.prepare(model)

# `transformers.AdamW` is deprecated and no longer shipped by recent
# transformers releases, so use the PyTorch implementation instead.
# eps=1e-6 reproduces the default of the old transformers optimizer
# (torch.optim.AdamW would otherwise default to 1e-8).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-1,
    weight_decay=0.0,
    eps=1e-6,
)
optimizer = accelerator.prepare(optimizer)


In [None]:
# Number of optimisation steps on the single fixed batch; defined once so the
# progress bar length and the loop length cannot drift apart.
NUM_STEPS = 100

# Repeatedly fit the same batch and show the current loss on the progress bar.
with tqdm(range(NUM_STEPS)) as pbar:
    for step in pbar:
        optimizer.zero_grad()
        # Call the module instance (not `.forward`) so forward hooks are run.
        output = model(input_dict)
        loss = output['image']['image']['loss']
        # Backward goes through the accelerator so mixed-precision handling
        # configured on it is applied.
        accelerator.backward(loss)
        optimizer.step()
        pbar.set_description(f'loss: {loss.item():.4f}')