In [None]:
%pip install -r requirements.txt

In [19]:
from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
import torch.nn as nn

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
vision_encoder = clip_model.vision_model
text_encoder = clip_model.text_model
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base")
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")


In [17]:
items = ["cat", "dog", "horse", "bird", "car", "person", "tree", "house", "book", "phone"]
classifiers = [f"a photo of a {item}" for item in items]

text_inputs = tokenizer(classifiers, return_tensors="pt")
text_features = text_encoder(**text_inputs).pooler_output
text_features_proj = clip_model.text_projection(text_features)

for item, classifier in zip(items, classifiers):
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt")
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    
    # Compute similarity to all classifier texts
    similarities = (image_features_proj @ text_features_proj.T).squeeze(0)
    best_idx = similarities.argmax().item()
    best_label = classifiers[best_idx]
    best_score = similarities[best_idx].item()
    
    print(f"{item}: best match is '{best_label}' (score: {best_score:.3f})")

cat: best match is 'a photo of a cat' (score: 30.380)
dog: best match is 'a photo of a dog' (score: 35.998)
horse: best match is 'a photo of a horse' (score: 36.964)
bird: best match is 'a photo of a bird' (score: 30.397)
car: best match is 'a photo of a car' (score: 23.982)
person: best match is 'a photo of a person' (score: 27.901)
tree: best match is 'a photo of a tree' (score: 30.656)
house: best match is 'a photo of a house' (score: 29.412)
book: best match is 'a photo of a book' (score: 40.067)
phone: best match is 'a photo of a phone' (score: 32.127)


In [None]:
clip_dim = clip_model.visual_projection.out_features  # 512
qwen_dim = qwen_model.model.embed_tokens.embedding_dim  # 4096

adapter = nn.Sequential(
    nn.Linear(clip_dim, qwen_dim),
    nn.LayerNorm(qwen_dim),
    nn.GELU(),
    nn.Linear(qwen_dim, qwen_dim),
)

for item in items:
    image = Image.open(f"images/{item}.jpg")
    image_inputs = image_processor(image, return_tensors="pt")
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = adapter(image_features_proj)
    attention_mask = torch.ones(1, 1)
    position_ids = torch.zeros(1, 1)

    generated_ids = qwen_model.generate(inputs_embeds=image_latent.unsqueeze(1), attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=30, do_sample=True, temperature=0.8, pad_token_id=qwen_tokenizer.pad_token_id, eos_token_id=qwen_tokenizer.eos_token_id)

    print(f"{item}: {qwen_tokenizer.decode(generated_ids[0])}")

cat: 乐观 Charter Candle Rescue作文住景供想象 Kiss charter试验今天的景 charter赎回收景acc1追赶景收景景景景景收募
dog: ii️À########
^)|;;ii️NN}}}️ennyII^^II|enny;;\tBBBB}}}iii×????ennyaaaaaaaa|
BBBB
horse: 磨 implied追赶.Unicode.Parcel.Parcelchargesaces派par.Parcel cualquier苗派 Paramountaces Alma Alma派 nar narкультур   charges nar captures派 nar追赶追赶
bird: 7完整 E base¡ |完全NC.subtract Letters住完整的 favorable |完整 Waveadera完全 Charterеспدد pledge Rدد¡دد7دد LIVENC
car: 绣_int绣夏夏rom只绣 rom绣为主绣owniorown绣夏绣ottie彩绣_rgb_rom高清高清为主rom只loat_pr
person: bohydrbohydrbohydrbohydrbohydrbohydrbohydrbohydrbohydrientbohydr่ายbohydrbohydrbohydrbohydrbohydr heatbohydrbohydrbohydrbohydrbohydrbohydrbohydrbohydrbohydr่ายbohydrbohydr
tree: uuoo^^¨oooaaaaÿÄeeeeoooooaaÄ^^??mm❤eeeeXXeeehhoooaaaaabcdooxxxxxxuuaaaa||
house: ện Labourар قناufficient قناufficientacious Labourufficient Labour...
iverse lợiiverseiverseufficient่ายufficientufficientufficientufficientacious labourLabour Labour Labourufficient laborufficient
book: 陪汗水微陪ート陪 armour陪�ooled armour汗水微耐�兼

In [32]:
# Download Flickr8k and fine-tune CLIP+Qwen (minimal demo)

import os
import requests
import zipfile
from PIL import Image
from tqdm import tqdm

# Download Flickr8k images
if not os.path.exists("Flickr8k_Dataset"):
    print("Downloading Flickr8k images...")
    url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
    r = requests.get(url, stream=True)
    with open("Flickr8k_Dataset.zip", "wb") as f:
        for chunk in tqdm(r.iter_content(chunk_size=8192)):
            f.write(chunk)
    with zipfile.ZipFile("Flickr8k_Dataset.zip", "r") as zip_ref:
        zip_ref.extractall(".")

# Download Flickr8k captions
if not os.path.exists("Flickr8k_text"):
    print("Downloading Flickr8k captions...")
    url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
    r = requests.get(url, stream=True)
    with open("Flickr8k_text.zip", "wb") as f:
        for chunk in tqdm(r.iter_content(chunk_size=8192)):
            f.write(chunk)
    with zipfile.ZipFile("Flickr8k_text.zip", "r") as zip_ref:
        zip_ref.extractall(".")

# Parse captions
captions = {}
with open("Flickr8k.token.txt", "r") as f:
    for line in f:
        img, caption = line.strip().split('\t')
        img = img.split('#')[0]
        captions.setdefault(img, []).append(caption)

# Minimal training loop (single image-caption pair per step, for demo)
from transformers import CLIPModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import torch.optim as optim

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
vision_encoder = clip_model.vision_model
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", torch_dtype=torch.float16)
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")

projection = nn.Linear(clip_model.visual_projection.out_features, qwen_model.model.embed_tokens.embedding_dim, dtype=torch.float16)
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
projection = projection.to(device)

optimizer = optim.AdamW(list(projection.parameters()) + list(qwen_model.parameters()), lr=1e-5)

# Training loop (VERY minimal, just for demonstration)
image_files = list(captions.keys())  # Use all images for demo
for epoch in range(10):  # Increase for real training
    for img_file in tqdm(image_files):
        img_path = os.path.join("Flickr8k_Dataset", "Flicker8k_Dataset", img_file)
        if not os.path.exists(img_path): continue
        caption = captions[img_file][0]  # Use the first caption

        # Image to latent
        image = Image.open(img_path).convert("RGB")
        image_inputs = image_processor(image, return_tensors="pt").to(device)
        image_features = vision_encoder(**image_inputs).pooler_output
        image_features_proj = clip_model.visual_projection(image_features)
        image_latent = projection(image_features_proj.half())

        # Caption to tokens
        input_ids = qwen_tokenizer(caption, return_tensors="pt", truncation=True, max_length=32).input_ids.to(device)
        text_embeds = qwen_model.model.embed_tokens(input_ids)
        decoder_input = torch.cat([image_latent.unsqueeze(1), text_embeds], dim=1)

        # Shift labels for causal LM
        labels = input_ids.clone()
        outputs = qwen_model(inputs_embeds=decoder_input, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item():.4f}")

print("Training loop complete (demo).")

image_folder = "images"
image_files = [f for f in os.listdir(image_folder) if f.endswith(".jpg")]

for img_file in image_files:
    image = Image.open(os.path.join(image_folder, img_file)).convert("RGB")
    image_inputs = image_processor(image, return_tensors="pt").to(device)
    image_features = vision_encoder(**image_inputs).pooler_output
    image_features_proj = clip_model.visual_projection(image_features)
    image_latent = projection(image_features_proj.half())

    # Generate text from image latent
    attention_mask = torch.ones(1, 1, dtype=torch.long).to(device)
    position_ids = torch.zeros(1, 1, dtype=torch.long).to(device)
    generated_ids = qwen_model.generate(
        inputs_embeds=image_latent.unsqueeze(1),
        attention_mask=attention_mask,
        position_ids=position_ids,
        max_new_tokens=30,
        do_sample=True,
        temperature=0.8,
        pad_token_id=qwen_tokenizer.pad_token_id,
        eos_token_id=qwen_tokenizer.eos_token_id,
    )
    description = qwen_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"{img_file}: {description}")

Downloading Flickr8k images...


136160it [00:09, 14316.11it/s]


Downloading Flickr8k captions...


286it [00:00, 2124.61it/s]
100%|██████████| 8092/8092 [00:00<00:00, 529407.39it/s]
100%|██████████| 8092/8092 [00:00<00:00, 529994.35it/s]
100%|██████████| 8092/8092 [00:00<00:00, 528163.40it/s]
100%|██████████| 8092/8092 [00:00<00:00, 523689.37it/s]
100%|██████████| 8092/8092 [00:00<00:00, 507116.72it/s]
100%|██████████| 8092/8092 [00:00<00:00, 525221.03it/s]
100%|██████████| 8092/8092 [00:00<00:00, 525668.43it/s]
100%|██████████| 8092/8092 [00:00<00:00, 525082.89it/s]
100%|██████████| 8092/8092 [00:00<00:00, 520078.27it/s]
100%|██████████| 8092/8092 [00:00<00:00, 525074.77it/s]


Training loop complete (demo).
dog.jpg: uuhanwwhanvvvvnnhanuu----------------------------------------------------------------------uuii:)uuiiuu~-)||ennyarry--------------------------------------------------------------------------------------------------------------------------------------------^^napjjoooo----------------------------------------------------------------------ennyiiuu
horse.jpg: aaaaaaaa''','''������������>>>>>>>>^^^^^^;;����cccc^^ooooaaaa||,,aaaaaaaauu||||^^ffffff;;&&>>>>>>>>^^','','                                                ","
book.jpg:  أما أما أما“ أما أما أما乔 أما أما أماечно أما أما أما أما أما أما أما أما恒 أما أما恒 أما أما أما外地 sağlam أما
bird.jpg: |||| jj;j;j------------->>>>>>>>;p&&||||;;;;jj XXX^^;;||||+
^^;;+
}}}||||]|}}}iiaaaaaaaa]|;p;;;;;;;;;;;;;;;;aaaa||||
person.jpg: ъъъappendъъъъъъъprint________________________________ъъъ:)ъъъъъъъъъъъъъ
house.jpg: aghdehyde(companyagh写作没写下aghputeraghaghubernetesgreSQLgreSQL耐磨agh写ubernetesWAREorry unix(...)
没 unix]-

In [None]:
from tqdm import tqdm
from PIL import Image
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPImageProcessor, AutoModelForCausalLM, AutoTokenizer

# Prepare captions dict
captions = {}
with open("Flickr8k.token.txt", "r") as f:
    for line in f:
        img, caption = line.strip().split('\t')
        img = img.split('#')[0]
        captions.setdefault(img, []).append(caption)

# Prepare list of (img_path, caption)
image_dir = "Flicker8k_Dataset"
pairs = []
for img in captions:
    img_path = os.path.join(image_dir, img)
    if os.path.exists(img_path):
        pairs.append((img_path, captions[img][0]))

# Model setup
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
vision_encoder = clip_model.vision_model
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
qwen_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B-Base", torch_dtype=torch.float16)
qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")

# Use float32 for projection for stability
projection = nn.Linear(clip_model.visual_projection.out_features, qwen_model.model.embed_tokens.embedding_dim, dtype=torch.float32)
nn.init.xavier_uniform_(projection.weight)
if projection.bias is not None:
    nn.init.zeros_(projection.bias)

# Device selection
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
clip_model = clip_model.to(device)
vision_encoder = vision_encoder.to(device)
qwen_model = qwen_model.to(device)
projection = projection.to(device)

# Freeze Qwen, only train projection
for param in qwen_model.parameters():
    param.requires_grad = False
for param in projection.parameters():
    param.requires_grad = True

optimizer = optim.AdamW(projection.parameters(), lr=1e-7)

# Collate function for batching and padding
def collate_fn(batch):
    img_paths, captions_ = zip(*batch)
    images = [Image.open(p).convert("RGB") for p in img_paths]
    image_inputs = [image_processor(img, return_tensors="pt") for img in images]
    merged_image_inputs = {}
    for k in image_inputs[0]:
        merged_image_inputs[k] = torch.cat([d[k] for d in image_inputs], dim=0)
    input_ids = [qwen_tokenizer(c, return_tensors="pt", truncation=True, max_length=32).input_ids[0] for c in captions_]
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=qwen_tokenizer.pad_token_id)
    return merged_image_inputs, input_ids_padded

# DataLoader
dataloader = DataLoader(pairs, batch_size=8, shuffle=True, collate_fn=collate_fn)

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    for image_inputs_batch, input_ids_batch in tqdm(dataloader, desc=f"Training epoch {epoch+1}"):
        for k in image_inputs_batch:
            image_inputs_batch[k] = image_inputs_batch[k].to(device)
        input_ids_batch = input_ids_batch.to(device)

        # Skip batches with only padding
        if (input_ids_batch != qwen_tokenizer.pad_token_id).sum().item() == 0:
            continue

        image_features = vision_encoder(**image_inputs_batch).pooler_output
        image_features_proj = clip_model.visual_projection(image_features)
        image_latent = projection(image_features_proj.float())

        text_embeds = qwen_model.model.embed_tokens(input_ids_batch)
        decoder_input = torch.cat([image_latent.unsqueeze(1), text_embeds], dim=1)

        # Pad labels for image token alignment
        labels = input_ids_batch.clone()
        labels = torch.cat([
            torch.full((labels.shape[0], 1), -100, dtype=labels.dtype, device=labels.device),
            labels
        ], dim=1)
        outputs = qwen_model(inputs_embeds=decoder_input, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(projection.parameters(), max_norm=1.0)
        optimizer.step()

        tqdm.write(f"Loss: {loss.item():.4f}")

print("Training loop complete.")

Epoch 1/3


Training epoch 1:   0%|          | 1/1012 [00:02<37:43,  2.24s/it]

Loss: 8.8220


Training epoch 1:   0%|          | 2/1012 [00:03<26:48,  1.59s/it]

Loss: nan


Training epoch 1:   0%|          | 3/1012 [00:03<19:10,  1.14s/it]

Loss: nan


Training epoch 1:   0%|          | 4/1012 [00:04<17:41,  1.05s/it]

Loss: nan


Training epoch 1:   0%|          | 5/1012 [00:06<18:34,  1.11s/it]

Loss: nan


Training epoch 1:   1%|          | 6/1012 [00:07<17:39,  1.05s/it]

Loss: nan


Training epoch 1:   1%|          | 7/1012 [00:07<15:28,  1.08it/s]

Loss: nan


Training epoch 1:   1%|          | 8/1012 [00:08<13:28,  1.24it/s]

Loss: nan


Training epoch 1:   1%|          | 9/1012 [00:09<14:22,  1.16it/s]

Loss: nan


Training epoch 1:   1%|          | 10/1012 [00:09<12:58,  1.29it/s]

Loss: nan


Training epoch 1:   1%|          | 11/1012 [00:10<11:57,  1.40it/s]

Loss: nan


Training epoch 1:   1%|          | 12/1012 [00:11<12:08,  1.37it/s]

Loss: nan


Training epoch 1:   1%|▏         | 13/1012 [00:11<11:52,  1.40it/s]

Loss: nan


Training epoch 1:   1%|▏         | 14/1012 [00:12<11:30,  1.45it/s]

Loss: nan


Training epoch 1:   1%|▏         | 15/1012 [00:13<12:17,  1.35it/s]

Loss: nan


Training epoch 1:   2%|▏         | 16/1012 [00:14<12:12,  1.36it/s]

Loss: nan


Training epoch 1:   2%|▏         | 17/1012 [00:14<12:33,  1.32it/s]

Loss: nan


Training epoch 1:   2%|▏         | 18/1012 [00:15<12:48,  1.29it/s]

Loss: nan


Training epoch 1:   2%|▏         | 19/1012 [00:16<13:21,  1.24it/s]

Loss: nan


Training epoch 1:   2%|▏         | 20/1012 [00:17<12:56,  1.28it/s]

Loss: nan


Training epoch 1:   2%|▏         | 21/1012 [00:17<12:00,  1.38it/s]

Loss: nan


Training epoch 1:   2%|▏         | 22/1012 [00:18<11:01,  1.50it/s]

Loss: nan


Training epoch 1:   2%|▏         | 23/1012 [00:19<12:00,  1.37it/s]

Loss: nan


Training epoch 1:   2%|▏         | 24/1012 [00:20<12:04,  1.36it/s]

Loss: nan


Training epoch 1:   2%|▏         | 25/1012 [00:20<11:30,  1.43it/s]

Loss: nan


Training epoch 1:   3%|▎         | 26/1012 [00:21<11:02,  1.49it/s]

Loss: nan


Training epoch 1:   3%|▎         | 27/1012 [00:21<11:03,  1.48it/s]

Loss: nan


Training epoch 1:   3%|▎         | 28/1012 [00:22<10:56,  1.50it/s]

Loss: nan


Training epoch 1:   3%|▎         | 29/1012 [00:23<10:57,  1.49it/s]

Loss: nan


Training epoch 1:   3%|▎         | 30/1012 [00:23<11:05,  1.48it/s]

Loss: nan


Training epoch 1:   3%|▎         | 31/1012 [00:24<12:30,  1.31it/s]

Loss: nan


Training epoch 1:   3%|▎         | 32/1012 [00:25<11:48,  1.38it/s]

Loss: nan


Training epoch 1:   3%|▎         | 33/1012 [00:26<11:23,  1.43it/s]

Loss: nan


Training epoch 1:   3%|▎         | 34/1012 [00:26<11:06,  1.47it/s]

Loss: nan


Training epoch 1:   3%|▎         | 35/1012 [00:27<10:39,  1.53it/s]

Loss: nan


Training epoch 1:   4%|▎         | 36/1012 [00:28<10:22,  1.57it/s]

Loss: nan


Training epoch 1:   4%|▎         | 37/1012 [00:28<10:35,  1.53it/s]

Loss: nan


Training epoch 1:   4%|▍         | 38/1012 [00:29<10:50,  1.50it/s]

Loss: nan


Training epoch 1:   4%|▍         | 39/1012 [00:30<10:44,  1.51it/s]

Loss: nan


Training epoch 1:   4%|▍         | 40/1012 [00:30<10:21,  1.56it/s]

Loss: nan


Training epoch 1:   4%|▍         | 41/1012 [00:31<09:58,  1.62it/s]

Loss: nan


Training epoch 1:   4%|▍         | 42/1012 [00:31<10:07,  1.60it/s]

Loss: nan


Training epoch 1:   4%|▍         | 43/1012 [00:32<09:36,  1.68it/s]

Loss: nan


Training epoch 1:   4%|▍         | 44/1012 [00:32<09:31,  1.69it/s]

Loss: nan


Training epoch 1:   4%|▍         | 45/1012 [00:33<09:11,  1.75it/s]

Loss: nan


Training epoch 1:   5%|▍         | 46/1012 [00:34<09:09,  1.76it/s]

Loss: nan


Training epoch 1:   5%|▍         | 47/1012 [00:34<09:35,  1.68it/s]

Loss: nan


Training epoch 1:   5%|▍         | 48/1012 [00:35<09:50,  1.63it/s]

Loss: nan


Training epoch 1:   5%|▍         | 49/1012 [00:35<09:47,  1.64it/s]

Loss: nan


Training epoch 1:   5%|▍         | 50/1012 [00:36<09:38,  1.66it/s]

Loss: nan


Training epoch 1:   5%|▌         | 51/1012 [00:37<09:42,  1.65it/s]

Loss: nan


Training epoch 1:   5%|▌         | 52/1012 [00:37<10:01,  1.60it/s]

Loss: nan


Training epoch 1:   5%|▌         | 53/1012 [00:38<10:05,  1.58it/s]

Loss: nan


Training epoch 1:   5%|▌         | 54/1012 [00:39<09:51,  1.62it/s]

Loss: nan


Training epoch 1:   5%|▌         | 55/1012 [00:39<09:45,  1.63it/s]

Loss: nan


Training epoch 1:   6%|▌         | 56/1012 [00:40<09:50,  1.62it/s]

Loss: nan


Training epoch 1:   6%|▌         | 57/1012 [00:40<09:43,  1.64it/s]

Loss: nan


Training epoch 1:   6%|▌         | 58/1012 [00:41<09:15,  1.72it/s]

Loss: nan


Training epoch 1:   6%|▌         | 59/1012 [00:41<09:12,  1.72it/s]

Loss: nan


Training epoch 1:   6%|▌         | 60/1012 [00:42<09:52,  1.61it/s]

Loss: nan


Training epoch 1:   6%|▌         | 61/1012 [00:43<09:31,  1.66it/s]

Loss: nan


Training epoch 1:   6%|▌         | 62/1012 [00:43<09:07,  1.74it/s]

Loss: nan


Training epoch 1:   6%|▌         | 63/1012 [00:44<09:25,  1.68it/s]

Loss: nan


Training epoch 1:   6%|▋         | 64/1012 [00:44<09:03,  1.74it/s]

Loss: nan


Training epoch 1:   6%|▋         | 65/1012 [00:45<09:26,  1.67it/s]

Loss: nan


Training epoch 1:   7%|▋         | 66/1012 [00:46<09:47,  1.61it/s]

Loss: nan


Training epoch 1:   7%|▋         | 67/1012 [00:46<10:02,  1.57it/s]

Loss: nan


Training epoch 1:   7%|▋         | 68/1012 [00:47<09:54,  1.59it/s]

Loss: nan


Training epoch 1:   7%|▋         | 69/1012 [00:48<09:42,  1.62it/s]

Loss: nan


Training epoch 1:   7%|▋         | 70/1012 [00:49<11:46,  1.33it/s]

Loss: nan


Training epoch 1:   7%|▋         | 71/1012 [00:49<11:38,  1.35it/s]

Loss: nan


Training epoch 1:   7%|▋         | 72/1012 [00:50<11:20,  1.38it/s]

Loss: nan
