
Update the code to support the most recent version of the deeplake package. #1

Open
wants to merge 14 commits into base: main
27 changes: 11 additions & 16 deletions README.md
@@ -4,27 +4,22 @@ A PyTorch Lightning solution to training CLIP from both scratch and fine-tuning

## Usage with Deep Lake 🚂

-```
-python train.py --path 's3://hub-2.0-datasets-n/laion400m-data' --model_name ViT-B/16 --batch_size 48 --accelerator gpu --gpus 2 --strategy ddp
-```
-
-to use fast deeplake dataloader, dataset need to be connected to your organization and your organization need to have credentials to access the dataset
+Replace the placeholder with your Activeloop API token and run the following command to initiate training.

```bash
-pip3 install --upgrade deeplake[enterprise]
-python
->>> import deeplake
->>> ds = deeplake.connect(src_path="s3://bucket/dataset", dest_path="hub://my_org/dataset", creds_key="my_managed_credentials_key", token="my_activeloop_token")
+python train.py \
+    --path 'hub://genai360/laion-400M' \
+    --token {your_activeloop_token} \
+    --model_name ViT-B/16 \
+    --batch_size 64 \
+    --accelerator gpu \
+    --gpus 2 \
+    --strategy ddp \
+    --filter_NSFW \
+    --fp 16
```

-
-and then specify token while training
-```bash
-python train.py --path 'hub://my_org/dataset' --token {your_token} --model_name ViT-B/16 --batch_size 48 --accelerator gpu --gpus 2 --strategy ddp
-```

### From Scratch 🌵
This training setup is easily usable right out of the box! Simply provide a training directory or your own dataset and we've got the rest covered. To train a model, just specify a model name from the paper and tell us your training folder and batch size. All possible models can be seen in the yaml files in `models/config`.
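For illustration, a from-scratch run might look like the sketch below. The `--folder` and `--batchsize` flag names are assumptions and are not taken from this diff; check `train.py` and the yaml files in `models/config` for the options your checkout actually exposes.

```bash
# Hypothetical from-scratch invocation; --folder and --batchsize are assumed
# flag names and may not match the argument parser in this fork's train.py.
python train.py --model_name ViT-B/16 --folder ./my_image_text_folder --batchsize 512
```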

72 changes: 51 additions & 21 deletions data/deeplake_dm.py
@@ -9,54 +9,79 @@
import torch
import deeplake
from deeplake.util.iterable_ordered_dict import IterableOrderedDict
-from deeplake.enterprise.dataloader import indra_available, dataloader
+from deeplake.enterprise.dataloader import indra_available

+class ConvertRGBAtoRGB:
+    def __call__(self, img):
+        if img.mode == 'RGBA':
+            # Convert RGBA to RGB
+            img = img.convert('RGB')
+        return img

def image_transform(img):
-    transform = T.Compose([T.ToTensor(),
+    if np.array(img).shape[0] == 1:
+        return torch.tensor([])  # Mark corrupted images, these will be removed in collate_fn
+
+    transform = T.Compose([T.ToPILImage(),
+                           ConvertRGBAtoRGB(),
+                           T.ToTensor(),
                           T.RandomResizedCrop(224, scale=(
                               0.75, 1.), ratio=(1., 1.)),
                           T.Normalize((0.48145466, 0.4578275, 0.40821073),
                                       (0.26862954, 0.26130258, 0.27577711))
                           ])
-    return transform(img)
+    t = transform(img)
+    return t


def txt_transform(txt: str, custom_tokenizer=False):
-    return txt if custom_tokenizer else clip.tokenize(txt, truncate=True)[0].numpy()
+    if custom_tokenizer:
+        return txt
+    else:
+        t = clip.tokenize(txt, truncate=True)[0].numpy()
+        return t


def collate_fn(batch, custom_tokenizer=False):

    elem = batch[0]

    if isinstance(elem, IterableOrderedDict):
-        return IterableOrderedDict(
-            (key, collate_fn([d[key] for d in batch])) for key in elem.keys()
-        )
+        key_value_pairs = []
+        for key in elem.keys():
+            values = []
+            for d in batch:
+                if d['URL'].numel() != 0:  # Remove corrupted images, the empty tensors (look at "image_transform" func)
+                    values.append(d[key])
+
+            processed_values = collate_fn(values)
+
+            key_value_pairs.append((key, processed_values))
+
+        return IterableOrderedDict(key_value_pairs)

    if custom_tokenizer:
        tokens = custom_tokenizer(
-            [row['caption'] for row in batch], padding=True, truncation=True, return_tensors="pt")
+            [row['TEXT'] for row in batch], padding=True, truncation=True, return_tensors="pt")
        batch = [(row[0], token) for row, token in zip(batch, tokens)]

    if isinstance(elem, np.ndarray) and elem.dtype.type is np.str_:
        batch = [it.item() for it in batch]

    return torch.utils.data._utils.collate.default_collate(batch)


class DeepLakeDataModule(LightningDataModule):
    def __init__(self,
                 batch_size: int = 8,
-                 num_workers: int = 2,
-                 num_threads: int = 8,
+                 num_workers: int = 1,
+                 num_threads: int = 1,
                 image_size: int = 224,
                 resize_ratio: int = 0.75,
                 shuffle: bool = False,
                 path: str = None,
                 token: str = None,
-                 custom_tokenizer: bool = False
+                 custom_tokenizer: bool = False,
+                 filter_NSFW: bool = False
                 ):
        """Create a text image datamodule from directories with congruent text and image names.

@@ -79,6 +104,11 @@ def __init__(self,
        self.custom_tokenizer = custom_tokenizer
        self.ds = deeplake.load(path, token=token)

+        if filter_NSFW:
+            print("Filtering NSFW records... (It might take a while)")
+            self.ds = self.ds.query("SELECT * WHERE NSFW=='UNLIKELY'")
+            print("Filtering Done.")

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = argparse.ArgumentParser(
@@ -110,20 +140,20 @@ def update_collate_fn(batch): return collate_fn(

        if indra_available():
            # Fast dataloader implemented in CPP
-            self.train_dl = dataloader(self.ds)\
-                .transform({'image': image_transform, 'caption': txt_transform_local})\
-                .batch(self.batch_size, drop_last=True)\
-                .shuffle(self.shuffle)\
-                .pytorch(num_workers=self.num_workers, collate_fn=update_collate_fn)
+            self.train_dl = self.ds.dataloader() \
+                .transform({'URL': image_transform, 'TEXT': txt_transform_local})\
+                .batch(self.batch_size, drop_last=True)\
+                .pytorch(num_workers=self.num_workers, collate_fn=update_collate_fn)
        else:
            # For Windows machines where the CPP implementation is not available.
            self.train_dl = self.ds.pytorch(
                num_workers=self.num_workers,
-                transform={'image': image_transform,
-                           'caption': txt_transform_local},
+                transform={'URL': image_transform,
+                           'TEXT': txt_transform_local},
                batch_size=self.batch_size,
-                tensors=['image', 'caption'],
+                tensors=['URL', 'TEXT'],
                shuffle=self.shuffle,
                drop_last=True,
                collate_fn=update_collate_fn)

        return self.train_dl
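For context, here is a minimal usage sketch of the updated `DeepLakeDataModule`, based on the constructor arguments shown in this diff. The dataset path, the token, and the `train_dataloader()` method name are assumptions for illustration, not part of the change.

```python
# Illustrative only. Path, token, and the train_dataloader() name are assumed
# (LightningDataModule convention); adjust them to your setup before running.
from data.deeplake_dm import DeepLakeDataModule

dm = DeepLakeDataModule(
    batch_size=64,
    num_workers=1,
    path='hub://genai360/laion-400M',   # placeholder Deep Lake dataset
    token='my_activeloop_token',        # placeholder Activeloop token
    filter_NSFW=True,                   # applies: SELECT * WHERE NSFW=='UNLIKELY'
)
loader = dm.train_dataloader()
batch = next(iter(loader))
images, texts = batch['URL'], batch['TEXT']  # tensor keys used throughout this PR
```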
10 changes: 5 additions & 5 deletions requirements.txt
@@ -1,7 +1,7 @@
git+https://github.com/openai/CLIP.git
-pytorch-lightning
-deeplake
+pytorch-lightning==1.06
+deeplake==3.8.6
git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup
-tokenizers
-torch
-torchvision
+tokenizers==0.14.1
+torch==2.1.0
+torchvision==0.16.0
4 changes: 3 additions & 1 deletion train.py
@@ -17,14 +17,16 @@ def main(hparams):
    del hparams.model_name

    dm = DeepLakeDataModule.from_argparse_args(hparams)
-    trainer = Trainer.from_argparse_args(hparams, precision=16, max_epochs=32, enable_model_summary=False)
+    trainer = Trainer.from_argparse_args(hparams, precision=hparams.fp, max_epochs=32, enable_model_summary=False)
    trainer.fit(model, dm)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--minibatch_size', type=int, default=0)
+    parser.add_argument('--filter_NSFW', action="store_true")
+    parser.add_argument('--fp', type=int, default=16)
    parser = DeepLakeDataModule.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()