## Load model weights

In [None]:
from segment_anything import build_sam, SamAutomaticMaskGenerator, sam_model_registry
from torchvision.io import ImageReadMode, read_image

# Build the automatic mask generator from the ViT-B checkpoint.
# The builder must match the checkpoint: `build_sam` constructs the ViT-H
# architecture, so the ViT-B weights are loaded through the registry's
# "vit_b" entry instead. (The original path also carried a stray leading "<".)
mask_generator = SamAutomaticMaskGenerator(
    sam_model_registry["vit_b"](checkpoint="/models/checkpoints/sam_vit_b_01ec64.pth")
)
testing_image = 'sample/test_image.jpg'

# generate() takes the image itself, not a path: an HWC uint8 RGB array.
# read_image returns a CHW uint8 tensor, so reorder the axes before converting.
testing_image_rgb = (
    read_image(testing_image, mode=ImageReadMode.RGB)
    .permute(1, 2, 0)
    .contiguous()
    .numpy()
)
masks = mask_generator.generate(testing_image_rgb)

In [None]:
# Load the SAM model used by the training cell below.
# The registry key must name an architecture ("default", "vit_h", "vit_l",
# "vit_b"); "vit_b" matches the sam_vit_b checkpoint file. The original line
# also ended with a doubled quote, which was a syntax error.
sam_model = sam_model_registry["vit_b"](checkpoint="/models/checkpoints/sam_vit_b_01ec64.pth")

## Training model

In [None]:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms

# Preprocessing pipeline: convert each image to a tensor, then normalise
# every channel as (x - 0.5) / 0.5.
channel_mean = [0.5, 0.5, 0.5]
channel_std = [0.5, 0.5, 0.5]
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Input pipeline: images are read from class sub-folders under
# 'samples/dataset' and served in shuffled batches of 16.
dataset = ImageFolder('samples/dataset', transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

# Training cell: replaces one convolution on `sam_model`, then runs 10 epochs
# of Adam + cross-entropy over `dataloader` and writes the resulting
# state_dict to disk. Relies on `sam_model` and `dataloader` created in
# earlier cells.

# Replace `backbone.conv1` with a newly constructed 3 -> 2048 channel
# convolution (35x35 kernel, stride 3, padding 3, no bias). Whatever weights
# that attribute held before are discarded.
# NOTE(review): `backbone.conv1` is the attribute path of torchvision
# ResNet-style models. As far as I can tell, segment_anything's Sam module
# exposes `image_encoder`, `prompt_encoder` and `mask_decoder` and has no
# `backbone` -- confirm, since this line would then raise AttributeError.
sam_model.backbone.conv1 = nn.Conv2d(3, 2048, kernel_size=35, stride=3, padding=3, bias=False)

# Loss and optimizer: cross-entropy against the ImageFolder labels, Adam over
# every parameter of `sam_model` with a learning rate of 0.001.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(sam_model.parameters(), lr=0.001)

# Train for 10 epochs, printing the loss after every batch.
# The original author states the images are 5000x3600; nothing in this cell
# resizes or checks the input size.
# NOTE(review): the model is never switched to train mode and no device
# placement happens here -- confirm that is intended.
for epoch in range(10):
    for i, data in enumerate(dataloader):
        # Each batch from the DataLoader is unpacked as (images, labels).
        inputs, labels = data
        optimizer.zero_grad()
        # NOTE(review): the model is called with a plain image batch and the
        # result is indexed with 'out', which is the convention of torchvision
        # segmentation models. segment_anything's Sam.forward appears to take
        # a list of per-image dicts plus `multimask_output`, to run under
        # torch.no_grad(), and to return a list of dicts -- confirm; if so,
        # this call and the backward pass below cannot work as written.
        outputs = sam_model(inputs)
        loss = criterion(outputs['out'], labels)
        loss.backward()
        optimizer.step()
        print('Epoch: {}, Batch: {}, Loss: {}'.format(epoch, i, loss.item()))

# Save the trained weights (state_dict only) to a path relative to the
# working directory.
# NOTE(review): torch.save does not create the `models/` directory, and this
# relative path differs from the absolute `/models/checkpoints/` location the
# checkpoint is loaded from -- confirm the intended destination.
torch.save(sam_model.state_dict(), 'models/sam_model.pth')