In [54]:
import os
import torch
from PIL import Image
import open_clip
import pandas as pd
from glob import glob
import itertools

In [112]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')

image = preprocess(Image.open("APTOS2019/test/anodr/5a27b9b2a9c1.png")).unsqueeze(0)
labels = ['no diabetic retinopathy', 
          'mild diabetic retinopathy',
         'moderate diabetic retinopathy',
         'severe diabetic retinopathy',
         'proliferative diabetic retinopathy']
text = tokenizer(labels)

with torch.no_grad(), torch.autocast("cuda"):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]



Label probs: tensor([[0.0514, 0.2001, 0.2181, 0.3106, 0.2199]])


In [113]:
labels = ['no diabetic retinopathy', 
          'mild diabetic retinopathy',
         'moderate diabetic retinopathy',
         'severe diabetic retinopathy',
         'proliferative diabetic retinopathy']

Train data

In [114]:
filepaths = []
title = []
for idx, cls in enumerate(sorted(glob('./APTOS2019/train/*'))):
    temp = glob(os.path.join(cls,'*'))
    filepaths.append(temp)
    for i in range(len(temp)):
        title.append(f'a photo of a {labels[idx]} fundus')

In [115]:
filepaths = list(itertools.chain.from_iterable(filepaths))

In [116]:
df = pd.DataFrame(np.vstack((filepaths,title)).T, columns=['filepath','title'])

In [117]:
df

Unnamed: 0,filepath,title
0,./APTOS2019/train/anodr/17f6c7072f61.png,a photo of a no diabetic retinopathy fundus
1,./APTOS2019/train/anodr/b09101adb478.png,a photo of a no diabetic retinopathy fundus
2,./APTOS2019/train/anodr/90a786abe58e.png,a photo of a no diabetic retinopathy fundus
3,./APTOS2019/train/anodr/71e43b4f8ba6.png,a photo of a no diabetic retinopathy fundus
4,./APTOS2019/train/anodr/e30a890600e1.png,a photo of a no diabetic retinopathy fundus
...,...,...
2043,./APTOS2019/train/eproliferativedr/bdff5d8bddf...,a photo of a proliferative diabetic retinopath...
2044,./APTOS2019/train/eproliferativedr/2fe06bedb2c...,a photo of a proliferative diabetic retinopath...
2045,./APTOS2019/train/eproliferativedr/873dcc0b468...,a photo of a proliferative diabetic retinopath...
2046,./APTOS2019/train/eproliferativedr/f2d2a0c9203...,a photo of a proliferative diabetic retinopath...


In [118]:
df.to_csv('train_data.csv',sep="\t") 

Validation data

In [119]:
filepaths = []
title = []
for idx, cls in enumerate(sorted(glob('./APTOS2019/val/*'))):
    temp = glob(os.path.join(cls,'*'))
    filepaths.append(temp)
    for i in range(len(temp)):
        title.append(f'a photo of a {labels[idx]} fundus')

In [120]:
filepaths = list(itertools.chain.from_iterable(filepaths))

In [121]:
df = pd.DataFrame(np.vstack((filepaths,title)).T, columns=['filepath','title'])

In [122]:
df

Unnamed: 0,filepath,title
0,./APTOS2019/val/anodr/4a5a6efc0bef.png,a photo of a no diabetic retinopathy fundus
1,./APTOS2019/val/anodr/a821b6ecef33.png,a photo of a no diabetic retinopathy fundus
2,./APTOS2019/val/anodr/ae5d31979f19.png,a photo of a no diabetic retinopathy fundus
3,./APTOS2019/val/anodr/3710ff45299c.png,a photo of a no diabetic retinopathy fundus
4,./APTOS2019/val/anodr/4c6c5a1bf5ab.png,a photo of a no diabetic retinopathy fundus
...,...,...
509,./APTOS2019/val/eproliferativedr/8bed09514c3b.png,a photo of a proliferative diabetic retinopath...
510,./APTOS2019/val/eproliferativedr/e821c1b6417a.png,a photo of a proliferative diabetic retinopath...
511,./APTOS2019/val/eproliferativedr/3ac3fbfca7d4.png,a photo of a proliferative diabetic retinopath...
512,./APTOS2019/val/eproliferativedr/4a3da369b227.png,a photo of a proliferative diabetic retinopath...


In [123]:
df.to_csv('val_data.csv',sep="\t") 

In [124]:
df = pd.read_csv("./train_data.csv", sep="\t")

images = df["filepath"].tolist()
captions = df["title"].tolist()

In [127]:
captions

['a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy fundus',
 'a photo of a no diabetic retinopathy f

Although files were prepared there is a runtime error

```python 
(openclip) wojciechmojsiejuk@Wojciechs-MBP open_clip % python -m open_clip_train.main \
    --save-frequency 1 \
    --zeroshot-frequency 1 \
    --train-data="/[...]/open_clip/train_data.csv"  \
    --val-data="/[...]/open_clip/val_data.csv"  \
    --csv-img-key filepath \
    --csv-caption-key title \
    --warmup 10000 \
    --batch-size=128 \
    --lr=1e-3 \
    --wd=0.1 \
    --epochs=30 \
    --workers=8 \
    --model RN50
```

```python
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/open_clip_train/main.py", line 555, in <module>
    main(sys.argv[1:])
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/open_clip_train/main.py", line 398, in main
    data = get_data(
           ^^^^^^^^^
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/open_clip_train/data.py", line 551, in get_data
    data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/open_clip_train/data.py", line 449, in get_csv_dataset
    dataset = CsvDataset(
              ^^^^^^^^^^^
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/open_clip_train/data.py", line 34, in __init__
    self.images = df[img_key].tolist()
                  ~~^^^^^^^^^
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/pandas/core/frame.py", line 4102, in __getitem__
    indexer = self.columns.get_loc(key)
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/openclip/lib/python3.12/site-packages/pandas/core/indexes/base.py", line 3812, in get_loc
    raise KeyError(key) from err
KeyError: 'filepath'
```