In [70]:
import torch
import whisper
from utils import get_device 

In [71]:
device = get_device()
model = whisper.load_model("tiny")
tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = torch.nn.CrossEntropyLoss()
engligh_token = "<|en|>"
welsh_token = 50297 #"<|cy|>"

Using device: mps


In [72]:
# Load and process the audio file
FILE_PATH = "../audio/MoreWelsh.m4a"
audio = whisper.load_audio(FILE_PATH)
audio = whisper.pad_or_trim(audio)
log_mel = whisper.log_mel_spectrogram(audio)

In [73]:
options = whisper.DecodingOptions()
response = whisper.decode(model, log_mel, options)
print("whisper prediction without fine tuning: ", response.text)

whisper prediction without fine tuning:  Sundin.wells. Anglesy.


In [75]:
# Preparing input for target for the model to train on and learn
ids = []
ids += [tokenizer.sot]
ids += [welsh_token]
ids += [tokenizer.transcribe]
ids += [tokenizer.no_timestamps]
ids += tokenizer.encode(" Llandrindod Wells, Anglesey")
ids += [tokenizer.eot]

model.train()

train_tokens = torch.tensor(ids).unsqueeze(0)
mel_unsqueezed = log_mel.unsqueeze(0) #.to(device)
prediction = model(tokens=train_tokens, mel=mel_unsqueezed)
target = train_tokens[:, 1:].contiguous()  # Skip the start token

print("--- Before training ---")
print("Ids Target: ", target.squeeze().tolist())
print("Ids Pred: ", torch.argmax(prediction, dim=-1).squeeze().tolist())
print("Text target: ", tokenizer.decode(target.squeeze().tolist()))
print("Text pred: ", tokenizer.decode(torch.argmax(prediction, dim=-1).squeeze().tolist()))
loss = criterion(prediction.transpose(1, 2), train_tokens)
print("Loss: ", loss.item())

--- Before training ---
Ids Target:  [50297, 50359, 50363, 441, 1661, 81, 471, 378, 36363, 11, 4521, 904, 2030, 50257]
Ids Pred:  [50259, 50358, 50363, 318, 86, 81, 471, 378, 16495, 4521, 4521, 401, 72, 13, 50257]
Text target:  <|cy|><|transcribe|><|notimestamps|> Llandrindod Wells, Anglesey<|endoftext|>
Text pred:  <|en|><|translate|><|notimestamps|> Swrindod Wales Ang Angoli.<|endoftext|>
Loss:  11.355557441711426


In [61]:
training_count = 0

In [80]:
# Training the model - re-run this cell to have it train multiple times
training_count += 1
print(f"---- Training count {training_count} ----")
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("--- After training ---")
model.eval()
prediction = model(tokens=train_tokens, mel=mel_unsqueezed)
prediction = prediction[:, :-1, :].contiguous()  # Remove the last token

print("Ids Target: ", target.squeeze().tolist())
print("Ids Pred: ", torch.argmax(prediction, dim=-1).squeeze().tolist())
print("Text target: ", tokenizer.decode(target.squeeze().tolist()))
print("Text pred: ", tokenizer.decode(torch.argmax(prediction, dim=-1).squeeze().tolist()))

loss = criterion(prediction.transpose(1, 2), target)
print("Loss: ", loss.item())


---- Training count 11 ----
--- After training ---
Ids Target:  [50297, 50359, 50363, 441, 1661, 81, 471, 378, 36363, 11, 4521, 904, 2030, 50257]
Ids Pred:  [50297, 50358, 50363, 441, 1661, 81, 471, 378, 36363, 11, 4521, 904, 2030, 50257]
Text target:  <|cy|><|transcribe|><|notimestamps|> Llandrindod Wells, Anglesey<|endoftext|>
Text pred:  <|cy|><|translate|><|notimestamps|> Llandrindod Wells, Anglesey<|endoftext|>
Loss:  0.5504031777381897


In [85]:
### find welsh token -> 50297: <|cy|>
token_english = 50259

for i in range(50200, 50500):
    token = tokenizer.decode([i])
    print(f"Token {i}: {token}")

Token 50200:  alcanz
Token 50201: éma
Token 50202:  incense
Token 50203:  harden
Token 50204:  granting
Token 50205:  Nai
Token 50206:  Firma
Token 50207:  hypoc
Token 50208: job
Token 50209:  RH
Token 50210: zur
Token 50211: иля
Token 50212:  ź
Token 50213:  dares
Token 50214: anh
Token 50215:  만큼
Token 50216:  cuestión
Token 50217:  Lima
Token 50218: 景
Token 50219:  assunto
Token 50220:  IPO
Token 50221:  Bengal
Token 50222:  Bier
Token 50223:  psyche
Token 50224:  acquainted
Token 50225:  Gün
Token 50226: ози
Token 50227: ścią
Token 50228: AG
Token 50229:  malfunction
Token 50230:  asteroids
Token 50231: irez
Token 50232: amorph
Token 50233:  сотруд
Token 50234:  freshwater
Token 50235:  arran
Token 50236:  пры
Token 50237: ног
Token 50238:  diabetic
Token 50239:  قال
Token 50240:  oppress
Token 50241:  capacitance
Token 50242: performance
Token 50243: crates
Token 50244:  apostle
Token 50245:  JEN
Token 50246: OULD
Token 50247: Intro
Token 50248:  stalls
Token 50249:  ABOUT
Token 5

In [67]:
# Load and process the audio file
FILE_PATH = "../audio/Clem--Bes.m4a"
audio = whisper.load_audio(FILE_PATH)
audio = whisper.pad_or_trim(audio)
log_mel = whisper.log_mel_spectrogram(audio)

In [None]:
# Preparing input for target for the model to train on and learn
ids = []
ids += [tokenizer.sot]
ids += [tokenizer.language_token]
ids += [tokenizer.transcribe]
ids += [tokenizer.no_timestamps]
ids += tokenizer.encode(" Hello, my name is Bes.")
ids += [tokenizer.eot]

model.train()

train_tokens = torch.tensor(ids).unsqueeze(0)
mel_unsqueezed = log_mel.unsqueeze(0) #.to(device)
prediction = model(tokens=train_tokens, mel=mel_unsqueezed)
target = train_tokens[:, 1:].contiguous()  # Skip the start token

print("--- Finetuned to Welsh ---")
print("Ids Target: ", target.squeeze().tolist())
print("Ids Pred: ", torch.argmax(prediction, dim=-1).squeeze().tolist())
print("Text target: ", tokenizer.decode(target.squeeze().tolist()))
print("Text pred: ", tokenizer.decode(torch.argmax(prediction, dim=-1).squeeze().tolist()))
loss = criterion(prediction.transpose(1, 2), train_tokens)
print("Loss: ", loss.item())

--- Before training ---
Ids Target:  [50259, 50359, 50363, 2425, 11, 452, 1315, 307, 8190, 13, 50257]
Ids Pred:  [50259, 50359, 50363, 2425, 11, 452, 1315, 307, 14011, 13, 50257, 50257]
Text target:  <|en|><|transcribe|><|notimestamps|> Hello, my name is Bes.<|endoftext|>
Text pred:  <|en|><|transcribe|><|notimestamps|> Hello, my name is Beth.<|endoftext|><|endoftext|>
Loss:  12.903071403503418


In [82]:
target = [50259, 50359, 50363, 2425, 11, 452, 1315, 307, 8190, 13, 50257]
pred = [50259, 50359, 50363, 2425, 11, 452, 1315, 307, 14011, 13, 50257, 50257]

In [84]:
import torch.nn.functional as F

# Convert lists to tensors
target_tensor = torch.tensor(target)
pred_tensor = torch.tensor(pred[:len(target)])  # Match target length

# Simulate logits: one-hot encode pred_tensor for demonstration
logits = F.one_hot(pred_tensor, num_classes=50363).float().unsqueeze(0)

# Reshape target for loss function
target_tensor = target_tensor.unsqueeze(0)

# Calculate loss
loss = criterion(logits, target_tensor)
print("Loss:", loss.item())

RuntimeError: Class values must be smaller than num_classes.