In [6]:
from datasets import load_dataset
import pathlib
import shutil

In [7]:
# def clean_ds_name():
#     path = pathlib.Path("data/town/train/Minowacho")
#     for file in list(path.iterdir()):
#         stem = file.stem.replace(".", "")
#         new_file = file.with_stem(stem)
#         file.rename(new_file.as_posix())

In [8]:
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, AutoProcessor

class CLIPDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform
        self.processor = AutoProcessor.from_pretrained("geolocal/StreetCLIP")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item["text"]
        image = Image.open(item["image_path"]).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        inputs = self.processor(
            text,
            images=image,
            return_tensors="pt", 
            padding="max_length",
            max_length=len("Hiyoshihoncho"),
        )
        inputs["input_ids"].squeeze_(dim=0)
        inputs["attention_mask"].squeeze_(dim=0)
        inputs["pixel_values"].squeeze_(dim=0)
        # print("input_ids", inputs["input_ids"].shape)
        # print("attention_mask", inputs["attention_mask"].shape)
        # print("pixel_values", inputs["pixel_values"].shape)        
        return inputs

# Transformations for the image
transform = transforms.Compose([
    transforms.Resize((336, 336)),
    transforms.ToTensor(),
    transforms.Normalize(
        [
            0.48145466,
            0.4578275,
            0.40821073
        ],
        [
            0.26862954,
            0.26130258,
            0.27577711
        ],
    ),
    transforms.ToPILImage(),
])

In [9]:

data = []
labels = {}

folders = list(pathlib.Path("./data/town/train/").iterdir())

for folder in folders:
    folder_name = folder.name
    if folder_name == ".DS_Store":
        continue
    labels[folder_name] = len(labels)
    for file in list(folder.iterdir()):
        d = {
            "text": folder_name,
            "image_path": file
        }
        data.append(d)

# Initialize the dataset
dataset = CLIPDataset(data, labels, transform)


In [10]:
from transformers import AutoModel, TrainingArguments, Trainer

# Load the model
model = AutoModel.from_pretrained("geolocal/StreetCLIP")
model.train()

# Specify the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    learning_rate=1e-5,
    logging_dir='./logs',
)

# Create the Trainer and train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

# Train the model
trainer.train()


`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


Step,Training Loss


TrainOutput(global_step=320, training_loss=1.2149492263793946, metrics={'train_runtime': 502.926, 'train_samples_per_second': 10.101, 'train_steps_per_second': 0.636, 'total_flos': 154275594267600.0, 'train_loss': 1.2149492263793946, 'epoch': 5.0})

In [13]:
trainer.save_model("./final-model")
model.push_to_hub('hiyoshi-street-clip')

pytorch_model.bin: 100%|██████████| 1.71G/1.71G [02:02<00:00, 13.9MB/s]
Upload 1 LFS files: 100%|██████████| 1/1 [02:03<00:00, 123.28s/it]


CommitInfo(commit_url='https://huggingface.co/fummicc1/hiyoshi-street-clip/commit/47011a9558d01aadc47629057ca145e25f507107', commit_message='Upload model', commit_description='', oid='47011a9558d01aadc47629057ca145e25f507107', pr_url=None, pr_revision=None, pr_num=None)